Skip to content

Commit ddb03bf

Browse files
jumerckxgithub-actions[bot]Pangoraw
authored
@trace if return nothing in branches (#1311)
* return nothing if the value if the result value in an if-statement is not required. * bump * don't create missing traced value if the variable exists before the if statement already. * actually fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * test * add test --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Paul Berg <[email protected]>
1 parent 0f0ded5 commit ddb03bf

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,10 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
414414
$(store_last_line) = $(true_last_line)
415415
end
416416
else
417-
expr.args[2]
417+
quote
418+
$(expr.args[2])
419+
nothing # explicitly return nothing to prevent branches from returning different types
420+
end
418421
end
419422

420423
true_branch_symbols = ExpressionExplorer.compute_symbols_state(true_block)
@@ -457,7 +460,10 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
457460
$(store_last_line) = $(false_last_line)
458461
end
459462
else
460-
else_block
463+
quote
464+
$else_block
465+
nothing # explicitly return nothing to prevent branches from returning different types
466+
end
461467
end
462468

463469
false_branch_symbols = ExpressionExplorer.compute_symbols_state(false_block)
@@ -473,10 +479,12 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
473479

474480
all_vars = all_input_vars all_output_vars
475481

476-
non_existant_true_branch_vars = setdiff(all_output_vars, all_true_branch_vars)
482+
non_existent_true_branch_vars = setdiff(
483+
all_output_vars, all_true_branch_vars, all_input_vars
484+
)
477485
true_branch_extras = Expr(
478486
:block,
479-
[:($(var) = $(MissingTracedValue)()) for var in non_existant_true_branch_vars]...,
487+
[:($(var) = $(MissingTracedValue)()) for var in non_existent_true_branch_vars]...,
480488
)
481489

482490
true_branch_fn = :(($(all_input_vars...),) -> begin
@@ -489,12 +497,12 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
489497
)
490498
true_branch_fn = :($(true_branch_fn_name) = $(true_branch_fn))
491499

492-
non_existant_false_branch_vars = setdiff(
493-
setdiff(all_output_vars, all_false_branch_vars), all_input_vars
500+
non_existent_false_branch_vars = setdiff(
501+
all_output_vars, all_false_branch_vars, all_input_vars
494502
)
495503
false_branch_extras = Expr(
496504
:block,
497-
[:($(var) = $(MissingTracedValue)()) for var in non_existant_false_branch_vars]...,
505+
[:($(var) = $(MissingTracedValue)()) for var in non_existent_false_branch_vars]...,
498506
)
499507

500508
false_branch_fn = :(($(all_input_vars...),) -> begin

test/control_flow.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,3 +934,36 @@ end
934934

935935
@test h_ra h
936936
end
937+
938+
function different_branch_returns(cond, a, b)
939+
@trace if cond
940+
a .= sin.(a)
941+
nothing
942+
else
943+
b .= sin.(b)
944+
nothing
945+
end
946+
return a, b
947+
end
948+
949+
@testset "one branch mutates variable" begin
950+
cond = true
951+
a = 3 .* ones(Float32, 2, 3)
952+
b = 4 .* ones(Float64, 2, 3)
953+
954+
cond_ra = ConcreteRNumber{Bool}(cond)
955+
a_ra = Reactant.to_rarray(a)
956+
b_ra = Reactant.to_rarray(b)
957+
958+
result_ra = @jit(different_branch_returns(cond_ra, a_ra, b_ra))
959+
result = different_branch_returns(cond, a, b)
960+
961+
@test result_ra[1] == a_ra
962+
@test result_ra[2] == b_ra
963+
964+
@test result[1] == a
965+
@test result[2] == b
966+
967+
@test a_ra a
968+
@test b_ra b
969+
end

0 commit comments

Comments
 (0)