@@ -14,6 +14,8 @@ import ..Reactant:
1414 append_path,
1515 TracedType
1616
17+ using ExpressionExplorer
18+
1719@inline traced_getfield (@nospecialize (obj), field) = Base. getfield (obj, field)
1820@inline traced_getfield (
1921 @nospecialize (obj:: AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}} ), field
@@ -432,6 +434,38 @@ macro jit(args...)
432434 # ! format: on
433435end
434436
437+ is_a_module (s:: Symbol ):: Bool = begin
438+ isdefined (@__MODULE__ , s) && getproperty (@__MODULE__ , s) isa Module
439+ end
440+
441+ # create expression for more complex expression than a call
442+ function wrapped_expression (expr:: Expr )
443+ args = ExpressionExplorer. compute_symbols_state (expr). references
444+ args = filter (! is_a_module, args)
445+ args = tuple (collect (args)... )
446+ fname = gensym (:F )
447+
448+ return (
449+ Expr (:tuple , args... ),
450+ quote
451+ ($ fname)($ (args... )) = $ expr
452+ end ,
453+ quote
454+ $ (fname)
455+ end ,
456+ )
457+ end
458+
459+ # check if an expression need to be wrap in a closure
460+ function need_wrap (expr:: Expr ):: Bool
461+ for arg in expr. args
462+ arg isa Expr || continue
463+ Meta. isexpr (arg, :.) && continue
464+ return true
465+ end
466+ return false
467+ end
468+
435469function compile_call_expr (mod, compiler, options, args... )
436470 while length (args) > 1
437471 option, args = args[1 ], args[2 : end ]
@@ -444,36 +478,39 @@ function compile_call_expr(mod, compiler, options, args...)
444478 end
445479 end
446480 call = only (args)
447- f_symbol = gensym (:f )
448481 args_symbol = gensym (:args )
449482 compiled_symbol = gensym (:compiled )
450-
451- if Meta. isexpr (call, :call )
452- bcast, fname, fname_full = correct_maybe_bcast_call (call. args[1 ])
453- fname = if bcast
454- quote
455- if isdefined (mod, $ (Meta. quot (fname_full)))
456- $ (fname_full)
457- else
458- Base. Broadcast. BroadcastFunction ($ (fname))
483+ closure = ()
484+ if call isa Expr && need_wrap (call)
485+ (args_rhs, closure, fname) = wrapped_expression (call)
486+ else
487+ if Meta. isexpr (call, :call )
488+ bcast, fname, fname_full = correct_maybe_bcast_call (call. args[1 ])
489+ fname = if bcast
490+ quote
491+ if isdefined (mod, $ (Meta. quot (fname_full)))
492+ $ (fname_full)
493+ else
494+ Base. Broadcast. BroadcastFunction ($ (fname))
495+ end
459496 end
497+ else
498+ :($ (fname))
460499 end
500+ args_rhs = Expr (:tuple , call. args[2 : end ]. .. )
501+ elseif Meta. isexpr (call, :(.), 2 ) && Meta. isexpr (call. args[2 ], :tuple )
502+ fname = :($ (Base. Broadcast. BroadcastFunction)($ (call. args[1 ])))
503+ args_rhs = only (call. args[2 : end ])
461504 else
462- :( $ (fname) )
505+ error ( " Invalid function call: $(call) " )
463506 end
464- args_rhs = Expr (:tuple , call. args[2 : end ]. .. )
465- elseif Meta. isexpr (call, :(.), 2 ) && Meta. isexpr (call. args[2 ], :tuple )
466- fname = :($ (Base. Broadcast. BroadcastFunction)($ (call. args[1 ])))
467- args_rhs = only (call. args[2 : end ])
468- else
469- error (" Invalid function call: $(call) " )
470507 end
471508
472509 return quote
473- $ (f_symbol) = $ (fname)
510+ $ closure
474511 $ (args_symbol) = $ (args_rhs)
475512 $ (compiled_symbol) = $ (compiler)(
476- $ (f_symbol ), $ (args_symbol); $ (Expr .(:kw , keys (options), values (options))... )
513+ $ (fname ), $ (args_symbol); $ (Expr .(:kw , keys (options), values (options))... )
477514 )
478515 end ,
479516 (; compiled= compiled_symbol, args= args_symbol)
0 commit comments