@@ -3,7 +3,7 @@ module ReactantCore
33using ExpressionExplorer: ExpressionExplorer
44using MacroTools: MacroTools
55
6- export @trace , MissingTracedValue
6+ export @trace , within_compile, MissingTracedValue
77
88# Traits
99is_traced (x) = false
@@ -21,6 +21,13 @@ const SPECIAL_SYMBOLS = [
2121 :(:), :nothing , :missing , :Inf , :Inf16 , :Inf32 , :Inf64 , :Base , :Core
2222]
2323
24+ """
25+ within_compile()
26+
27+ Returns true if this function is executed in a Reactant compilation context, otherwise false.
28+ """
29+ @inline within_compile () = false # behavior is overwritten in Interpreter.jl
30+
2431# Code generation
2532"""
2633 @trace <expr>
@@ -117,6 +124,13 @@ macro trace(expr)
117124 return esc (trace_if_with_returns (__module__, expr))
118125 end
119126 end
127+ Meta. isexpr (expr, :call ) && return esc (trace_call (__module__, expr))
128+ if Meta. isexpr (expr, :(.), 2 ) && Meta. isexpr (expr. args[2 ], :tuple )
129+ fname = :($ (Base. Broadcast. BroadcastFunction)($ (expr. args[1 ])))
130+ args = only (expr. args[2 : end ]). args
131+ call = Expr (:call , fname, args... )
132+ return esc (trace_call (__module__, call))
133+ end
120134 Meta. isexpr (expr, :if ) && return esc (trace_if (__module__, expr))
121135 Meta. isexpr (expr, :for ) && return (esc (trace_for (__module__, expr)))
122136 return error (" Only `if-elseif-else` blocks are currently supported by `@trace`" )
@@ -196,7 +210,9 @@ function trace_for(mod, expr)
196210 end
197211
198212 return quote
199- if any ($ (is_traced), $ (Expr (:tuple , cond_val .(all_syms. args[(begin + 1 ): end ])... )))
213+ if $ (within_compile)() && $ (any)(
214+ $ (is_traced), $ (Expr (:tuple , cond_val .(all_syms. args[(begin + 1 ): end ])... ))
215+ )
200216 $ (reactant_code_block)
201217 else
202218 $ (expr)
@@ -210,7 +226,7 @@ function trace_if_with_returns(mod, expr)
210226 mod, expr. args[2 ]; store_last_line= expr. args[1 ], depth= 1
211227 )
212228 return quote
213- if any ($ (is_traced), ($ (all_check_vars... ),))
229+ if $ (within_compile)() && $ ( any) ($ (is_traced), ($ (all_check_vars... ),))
214230 $ (new_expr)
215231 else
216232 $ (expr)
@@ -356,14 +372,41 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
356372 )
357373
358374 return quote
359- if any ($ (is_traced), ($ (all_check_vars... ),))
375+ if $ (within_compile)() && $ ( any) ($ (is_traced), ($ (all_check_vars... ),))
360376 $ (reactant_code_block)
361377 else
362378 $ (original_expr)
363379 end
364380 end
365381end
366382
383+ function correct_maybe_bcast_call (fname)
384+ startswith (string (fname), ' .' ) || return false , fname, fname
385+ return true , Symbol (string (fname)[2 : end ]), fname
386+ end
387+
388+ function trace_call (mod, call)
389+ bcast, fname, fname_full = correct_maybe_bcast_call (call. args[1 ])
390+ f = if bcast
391+ quote
392+ if isdefined (mod, $ (Meta. quot (fname_full)))
393+ $ (fname_full)
394+ else
395+ Base. Broadcast. BroadcastFunction ($ (fname))
396+ end
397+ end
398+ else
399+ :($ (fname))
400+ end
401+ return quote
402+ if $ (within_compile)()
403+ $ (traced_call)($ f, $ (call. args[2 : end ]. .. ))
404+ else
405+ $ (call)
406+ end
407+ end
408+ end
409+
367410function remove_shortcircuiting (expr)
368411 return MacroTools. prewalk (expr) do x
369412 if MacroTools. @capture (x, a_ && b_)
382425
383426function traced_while end # defined inside Reactant.jl
384427
428+ traced_call (f, args... ; kwargs... ) = f (args... ; kwargs... )
429+
385430function cleanup_expr_to_avoid_boxing (expr, prepend:: Symbol , all_vars)
386431 return MacroTools. postwalk (expr) do x
387432 if Meta. isexpr (x, :kw ) # undo lhs rewriting
0 commit comments