Skip to content

Commit 4725957

Browse files
committed
restore support for '@jit foo(Reactant.to_rarray(rand(2)))' and add '@jit foo(foo(Reactant.to_rarray(rand(2))))'
1 parent 38e10fe commit 4725957

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

src/Compiler.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,28 @@ end
440440

441441
#create expression for more complex expression than a call
442442
function wrapped_expression(expr::Expr)
443-
args = ExpressionExplorer.compute_symbols_state(expr).references
444-
args = filter(!is_a_module, args)
445-
args = tuple(collect(args)...)
443+
css = ExpressionExplorer.compute_symbols_state(expr)
444+
tracked_definitions = []
445+
tracked_names = []
446+
alter_expr = (e::Expr) -> begin
447+
for (i, arg) in enumerate(e.args)
448+
arg isa Expr && alter_expr(arg)
449+
is_tracking_call(arg) || continue
450+
name = gensym(:tracked)
451+
push!(tracked_definitions, arg)
452+
push!(tracked_names, name)
453+
e.args[i] = name
454+
end
455+
end
456+
alter_expr(expr)
457+
458+
free_args = collect(css.references)
459+
function_args = tuple([free_args; tracked_definitions]...)
460+
args = tuple([free_args; tracked_names]...)
446461
fname = gensym(:F)
447462

448463
return (
449-
Expr(:tuple, args...),
464+
Expr(:tuple, function_args...),
450465
quote
451466
($fname)($(args...)) = $expr
452467
end,
@@ -456,11 +471,18 @@ function wrapped_expression(expr::Expr)
456471
)
457472
end
458473

474+
function is_tracking_call(input)
475+
Meta.isexpr(input, :call) || return false
476+
function_name = (ExpressionExplorer.explore_funcdef!(input, ExpressionExplorer.ScopeState()))[1].parts[end]
477+
function_name in [:to_rarray, :ConcreteRNumber]
478+
end
479+
459480
#check if an expression need to be wrap in a closure
460481
function need_wrap(expr::Expr)::Bool
461482
for arg in expr.args
462483
arg isa Expr || continue
463484
Meta.isexpr(arg, :.) && continue
485+
is_tracking_call(arg) && continue
464486
return true
465487
end
466488
return false

0 commit comments

Comments
 (0)