Skip to content

Commit 632ef06

Browse files
Fix verify_arg_names when generated function has sizeof != 0 (#1321)
* fix `verify_arg_names` when generated function has sizeof != 0 * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * bump * Update lib/ReactantCore/src/ReactantCore.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * woopsie * test nested loop --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 80499b0 commit 632ef06

File tree

5 files changed

+45
-12
lines changed

5 files changed

+45
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Preferences = "1.4"
8686
PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
89-
ReactantCore = "0.1.9"
89+
ReactantCore = "0.1.10"
9090
Reactant_jll = "0.0.187"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.9"
4+
version = "0.1.10"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,14 @@ function trace_for(mod, expr; track_numbers)
215215
end
216216
) for (s, ref) in zip(external_syms, ref_syms)
217217
]
218+
body_fn_sym = gensym(:body_fn)
219+
cond_fn_sym = gensym(:cond_fn)
220+
args_sym = gensym(:args)
221+
verify_arg_names_sym = gensym(:verify_arg_names)
218222

219223
reactant_code_block = quote
220-
let args = $(args_init)
221-
cond_fn =
224+
let $(args_sym) = $(args_init)
225+
$(cond_fn_sym) =
222226
$(arg_syms) -> begin
223227
$(to_locals...)
224228
local num_iters = div($limit - $start, $step, RoundDown)
@@ -227,7 +231,7 @@ function trace_for(mod, expr; track_numbers)
227231
)
228232
$counter[] < num_iters + 1
229233
end
230-
body_fn =
234+
$(body_fn_sym) =
231235
$(arg_syms) -> begin
232236
local step_ = $step
233237
local start_ = $start
@@ -238,13 +242,17 @@ function trace_for(mod, expr; track_numbers)
238242
$counter[].mlir_data = ($counter[] + 1).mlir_data
239243
nothing
240244
end
241-
245+
$(verify_arg_names_sym) = if sizeof($(cond_fn_sym)) != 0
246+
(Symbol($cond_fn_sym), $(QuoteNode.(args_names.args)...))
247+
else
248+
($(QuoteNode.(args_names.args)...),)
249+
end
242250
$(ReactantCore).traced_while(
243-
cond_fn,
244-
body_fn,
245-
args;
251+
$(cond_fn_sym),
252+
$(body_fn_sym),
253+
$(args_sym);
246254
track_numbers=$(track_numbers),
247-
verify_arg_names=$(QuoteNode(args_names)),
255+
verify_arg_names=$(verify_arg_names_sym),
248256
)
249257
end
250258
end

src/TracedUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ function prepare_mlir_fn_args(
495495
stridx = if verify_arg_names isa Nothing
496496
"arg" * string(path[2])
497497
else
498-
string(verify_arg_names.args[path[2]])
498+
string(verify_arg_names[path[2]])
499499
end
500500
aval = args[path[2]]
501501
for (cidx, idx) in enumerate(path[3:end])
@@ -683,7 +683,7 @@ function finalize_mlir_fn(
683683
for (errs, prev, post) in ((err1, resis, argis), (err2, argis, resis))
684684
conflicts = setdiff(prev, post)
685685
for conflict in conflicts
686-
stridx = string(verify_arg_names.args[conflict[1]])
686+
stridx = string(verify_arg_names[conflict[1]])
687687
aval = args[conflict[1]]
688688
for (cidx, idx) in enumerate(Base.tail(conflict))
689689
if aval isa Array || aval isa Dict

test/control_flow.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,3 +852,28 @@ end
852852

853853
@test @jit(loop_batched(x_ra)) loop_batched(x)
854854
end
855+
856+
function loop!(h_mat::AbstractMatrix, η_mat::AbstractMatrix, H_mat::AbstractMatrix)
857+
m,n = size(h_mat)
858+
@inbounds @trace for i in 1:m
859+
@trace for j in 1:n
860+
@allowscalar h_mat[i,j] = η_mat[i,j] + H_mat[i,j]
861+
end
862+
end
863+
end
864+
865+
@testset "loop! with nested traced loops and scalar setindex!" begin
866+
h = zeros(Float64, 2, 3)
867+
η = [1.0 2.0 3.0; 4.0 5.0 6.0]
868+
H = [0.5 1.5 2.5; 3.5 4.5 5.5]
869+
870+
h_ra = Reactant.to_rarray(h)
871+
η_ra = Reactant.to_rarray(η)
872+
H_ra = Reactant.to_rarray(H)
873+
874+
@jit loop!(h_ra, η_ra, H_ra)
875+
876+
loop!(h, η, H)
877+
878+
@test h_ra h
879+
end

0 commit comments

Comments
 (0)