Skip to content

Commit 01fdef6

Browse files
authored
@trace: Promote loop variables to traced numbers (#1381)
* `@trace`: Promote loop variables to traced numbers * fmt * rm * Update ReactantCore.jl
1 parent 56c07f6 commit 01fdef6

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

docs/src/.vitepress/config.mts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ export default defineConfig({
155155
{ text: "Profiling", link: "/tutorials/profiling" },
156156
{ text: "Distributed", link: "/tutorials/multihost" },
157157
{ text: "Local build", link: "/tutorials/local-build" },
158+
{ text: "Control Flow", link: "/tutorials/control-flow" },
158159
],
159160
}
160161
],

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ function trace_while(expr; track_numbers, first_arg=nothing)
231231
$(cond_fn_sym),
232232
$(body_fn_sym),
233233
$(args_sym);
234-
track_numbers=$(track_numbers),
235-
verify_arg_names=$(verify_arg_names_sym),
234+
track_numbers=($(track_numbers)),
235+
verify_arg_names=($(verify_arg_names_sym)),
236236
)
237237
end
238238
end
@@ -289,10 +289,26 @@ function trace_for(expr; track_numbers)
289289
end
290290
end
291291

292-
quote
292+
return quote
293293
local $start_sym, $limit_sym, $step_sym
294294
$bounds_defs
295-
local $counter = 0
295+
296+
if $(within_compile)()
297+
$start_sym = Reactant.TracedUtils.promote_to(
298+
Reactant.TracedRNumber{Reactant.unwrapped_eltype(typeof($start_sym))},
299+
$start_sym,
300+
)
301+
$limit_sym = Reactant.TracedUtils.promote_to(
302+
Reactant.TracedRNumber{Reactant.unwrapped_eltype(typeof($limit_sym))},
303+
$limit_sym,
304+
)
305+
$step_sym = Reactant.TracedUtils.promote_to(
306+
Reactant.TracedRNumber{Reactant.unwrapped_eltype(typeof($step_sym))},
307+
$step_sym,
308+
)
309+
end
310+
311+
local $counter = zero($start_sym)
296312

297313
$(trace_while(
298314
Expr(
@@ -471,7 +487,7 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
471487
$(true_branch_fn_name),
472488
$(false_branch_fn_name),
473489
($(all_input_vars...),);
474-
track_numbers=$(track_numbers),
490+
track_numbers=($(track_numbers)),
475491
)
476492
end
477493

test/control_flow.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ function condition_with_structure(x)
461461
@trace if sum(y) > 0
462462
z = (; a=y, b=(y .- 1, y))
463463
else
464-
z = (; a=-y, b=(y, y .+ 1))
464+
z = (; a=(-y), b=(y, y .+ 1))
465465
end
466466
return z
467467
end
@@ -676,6 +676,27 @@ function while_convergence(x, y)
676676
return diff
677677
end
678678

679+
function for_no_track_numbers(x, n)
680+
@trace track_numbers = false for i in n:16
681+
x = x .+ 1
682+
end
683+
return x
684+
end
685+
686+
@testset "for: track_numbers=false" begin
687+
x = [1, 2, 3]
688+
x_ra = Reactant.to_rarray(x)
689+
690+
n = 12
691+
n_ra = Reactant.ConcreteRNumber(n)
692+
693+
# set optimize to only do enzyme-batch to prevent crash in opt
694+
for_no_track_numbers_ra = @compile optimize="enzyme-batch" for_no_track_numbers(
695+
x_ra, n_ra
696+
)
697+
for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)
698+
end
699+
679700
@testset "while: convergence" begin
680701
x = [1.0, 10.0, 20.0]
681702
y = [0.0, -2.0, -3.0]

0 commit comments

Comments
 (0)