Skip to content

Commit 4cbad40

Browse files
authored
cf: extract condition to a name (#801)
* cf: extract condition to a name Closes #768. * handle if with return
1 parent c9f4586 commit 4cbad40

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,11 @@ function trace_if_with_returns(mod, expr)
225225
new_expr, _, all_check_vars = trace_if(
226226
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
227227
)
228+
cond_name = first(all_check_vars)
229+
original_cond = expr.args[2].args[1]
230+
expr.args[2].args[1] = cond_name
228231
return quote
232+
$(cond_name) = $(original_cond)
229233
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
230234
$(new_expr)
231235
else
@@ -353,29 +357,42 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
353357
)
354358
false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn))
355359

360+
cond_name = gensym(:cond)
361+
356362
reactant_code_block = quote
357363
$(true_branch_fn)
358364
$(false_branch_fn)
359365
($(all_output_vars...),) = $(traced_if)(
360-
$(cond_expr),
366+
$(cond_name),
361367
$(true_branch_fn_name),
362368
$(false_branch_fn_name),
363369
($(all_input_vars...),),
364370
)
365371
end
366372

367-
all_check_vars = [all_input_vars..., condition_vars...]
373+
non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2])
374+
if length(original_expr.args) > 2 # has else block
375+
append!(non_reactant_code_block.args, original_expr.args[3:end])
376+
end
377+
378+
all_check_vars = [cond_name, all_input_vars..., condition_vars...]
368379
unique!(all_check_vars)
369380

370381
depth > 0 && return (
371-
reactant_code_block, (true_branch_fn_name, false_branch_fn_name), all_check_vars
382+
quote
383+
$(cond_name) = $(cond_expr)
384+
$(reactant_code_block)
385+
end,
386+
(true_branch_fn_name, false_branch_fn_name),
387+
all_check_vars,
372388
)
373389

374390
return quote
391+
$(cond_name) = $(cond_expr)
375392
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
376393
$(reactant_code_block)
377394
else
378-
$(original_expr)
395+
$(non_reactant_code_block)
379396
end
380397
end
381398
end

test/control_flow.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,8 +749,7 @@ mutable struct TestSimulation{C,I,B}
749749
end
750750

751751
function step!(sim)
752-
cond = sim.clock.iteration >= sim.stop_iteration
753-
@trace if cond
752+
@trace if sim.clock.iteration >= sim.stop_iteration
754753
sim.running = false
755754
else
756755
sim.clock.iteration += 1 # time step

0 commit comments

Comments
 (0)