Skip to content

Commit 38e10fe

Browse files
committed
add support to complex expr in compile_call_expr
1 parent 66d6cfc commit 38e10fe

File tree

3 files changed

+67
-21
lines changed

3 files changed

+67
-21
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
99
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
1010
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1111
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
12+
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
1213
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1314
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -25,8 +26,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2526
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2627
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2728

28-
[sources.ReactantCore]
29-
path = "lib/ReactantCore"
29+
[sources]
30+
ReactantCore = {path = "lib/ReactantCore"}
3031

3132
[extensions]
3233
ReactantAbstractFFTsExt = "AbstractFFTs"
@@ -43,6 +44,7 @@ CEnum = "0.4, 0.5"
4344
Downloads = "1.6"
4445
Enzyme = "0.13.21"
4546
EnzymeCore = "0.8.6, 0.8.7, 0.8.8"
47+
ExpressionExplorer = "1.1.0"
4648
GPUArraysCore = "0.1.6, 0.2"
4749
LinearAlgebra = "1.10"
4850
NNlib = "0.9.24"

src/Compiler.jl

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import ..Reactant:
1414
append_path,
1515
TracedType
1616

17+
using ExpressionExplorer
18+
1719
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
1820
@inline traced_getfield(
1921
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
@@ -432,6 +434,38 @@ macro jit(args...)
432434
#! format: on
433435
end
434436

437+
is_a_module(s::Symbol)::Bool = begin
438+
isdefined(@__MODULE__, s) && getproperty(@__MODULE__, s) isa Module
439+
end
440+
441+
#create expression for more complex expression than a call
442+
function wrapped_expression(expr::Expr)
443+
args = ExpressionExplorer.compute_symbols_state(expr).references
444+
args = filter(!is_a_module, args)
445+
args = tuple(collect(args)...)
446+
fname = gensym(:F)
447+
448+
return (
449+
Expr(:tuple, args...),
450+
quote
451+
($fname)($(args...)) = $expr
452+
end,
453+
quote
454+
$(fname)
455+
end,
456+
)
457+
end
458+
459+
#check if an expression need to be wrap in a closure
460+
function need_wrap(expr::Expr)::Bool
461+
for arg in expr.args
462+
arg isa Expr || continue
463+
Meta.isexpr(arg, :.) && continue
464+
return true
465+
end
466+
return false
467+
end
468+
435469
function compile_call_expr(mod, compiler, options, args...)
436470
while length(args) > 1
437471
option, args = args[1], args[2:end]
@@ -444,36 +478,39 @@ function compile_call_expr(mod, compiler, options, args...)
444478
end
445479
end
446480
call = only(args)
447-
f_symbol = gensym(:f)
448481
args_symbol = gensym(:args)
449482
compiled_symbol = gensym(:compiled)
450-
451-
if Meta.isexpr(call, :call)
452-
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
453-
fname = if bcast
454-
quote
455-
if isdefined(mod, $(Meta.quot(fname_full)))
456-
$(fname_full)
457-
else
458-
Base.Broadcast.BroadcastFunction($(fname))
483+
closure = ()
484+
if call isa Expr && need_wrap(call)
485+
(args_rhs, closure, fname) = wrapped_expression(call)
486+
else
487+
if Meta.isexpr(call, :call)
488+
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
489+
fname = if bcast
490+
quote
491+
if isdefined(mod, $(Meta.quot(fname_full)))
492+
$(fname_full)
493+
else
494+
Base.Broadcast.BroadcastFunction($(fname))
495+
end
459496
end
497+
else
498+
:($(fname))
460499
end
500+
args_rhs = Expr(:tuple, call.args[2:end]...)
501+
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
502+
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
503+
args_rhs = only(call.args[2:end])
461504
else
462-
:($(fname))
505+
error("Invalid function call: $(call)")
463506
end
464-
args_rhs = Expr(:tuple, call.args[2:end]...)
465-
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
466-
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
467-
args_rhs = only(call.args[2:end])
468-
else
469-
error("Invalid function call: $(call)")
470507
end
471508

472509
return quote
473-
$(f_symbol) = $(fname)
510+
$closure
474511
$(args_symbol) = $(args_rhs)
475512
$(compiled_symbol) = $(compiler)(
476-
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
513+
$(fname), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
477514
)
478515
end,
479516
(; compiled=compiled_symbol, args=args_symbol)

test/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ f_var(args...) = sum(args)
101101
@test @jit(f_var(x, y, z)) [6.6, 6.6, 6.6]
102102
end
103103

104+
@testset "Complex expression" begin
105+
x = Reactant.to_rarray(ones(3))
106+
f(x) = x .+ 1
107+
@test @jit(x + x - x + x * float(Base.pi) * 0) x
108+
@test @jit(f(f(f(f(x))))) @allowscalar x .+ 4
109+
end
110+
104111
function sumcos(x)
105112
return sum(cos.(x))
106113
end

0 commit comments

Comments
 (0)