Skip to content

Commit 234d168

Browse files
authored
@trace function calls (#366)
1 parent cac6f49 commit 234d168

File tree

11 files changed

+466
-48
lines changed

11 files changed

+466
-48
lines changed

docs/src/api/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ Reactant.@jit
1313

1414
## ReactantCore API
1515

16+
```@docs
17+
within_compile
18+
```
19+
1620
```@docs
1721
@trace
1822
```

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ReactantCore
33
using ExpressionExplorer: ExpressionExplorer
44
using MacroTools: MacroTools
55

6-
export @trace, MissingTracedValue
6+
export @trace, within_compile, MissingTracedValue
77

88
# Traits
99
is_traced(x) = false
@@ -21,6 +21,13 @@ const SPECIAL_SYMBOLS = [
2121
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
2222
]
2323

24+
"""
25+
within_compile()
26+
27+
Returns true if this function is executed in a Reactant compilation context, otherwise false.
28+
"""
29+
@inline within_compile() = false # behavior is overwritten in Interpreter.jl
30+
2431
# Code generation
2532
"""
2633
@trace <expr>
@@ -117,6 +124,13 @@ macro trace(expr)
117124
return esc(trace_if_with_returns(__module__, expr))
118125
end
119126
end
127+
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
128+
if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple)
129+
fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1])))
130+
args = only(expr.args[2:end]).args
131+
call = Expr(:call, fname, args...)
132+
return esc(trace_call(__module__, call))
133+
end
120134
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
121135
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
122136
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
@@ -196,7 +210,9 @@ function trace_for(mod, expr)
196210
end
197211

198212
return quote
199-
if any($(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...)))
213+
if $(within_compile)() && $(any)(
214+
$(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...))
215+
)
200216
$(reactant_code_block)
201217
else
202218
$(expr)
@@ -210,7 +226,7 @@ function trace_if_with_returns(mod, expr)
210226
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
211227
)
212228
return quote
213-
if any($(is_traced), ($(all_check_vars...),))
229+
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
214230
$(new_expr)
215231
else
216232
$(expr)
@@ -356,14 +372,41 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
356372
)
357373

358374
return quote
359-
if any($(is_traced), ($(all_check_vars...),))
375+
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
360376
$(reactant_code_block)
361377
else
362378
$(original_expr)
363379
end
364380
end
365381
end
366382

383+
function correct_maybe_bcast_call(fname)
384+
startswith(string(fname), '.') || return false, fname, fname
385+
return true, Symbol(string(fname)[2:end]), fname
386+
end
387+
388+
function trace_call(mod, call)
389+
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
390+
f = if bcast
391+
quote
392+
if isdefined(mod, $(Meta.quot(fname_full)))
393+
$(fname_full)
394+
else
395+
Base.Broadcast.BroadcastFunction($(fname))
396+
end
397+
end
398+
else
399+
:($(fname))
400+
end
401+
return quote
402+
if $(within_compile)()
403+
$(traced_call)($f, $(call.args[2:end]...))
404+
else
405+
$(call)
406+
end
407+
end
408+
end
409+
367410
function remove_shortcircuiting(expr)
368411
return MacroTools.prewalk(expr) do x
369412
if MacroTools.@capture(x, a_ && b_)
@@ -382,6 +425,8 @@ end
382425

383426
function traced_while end # defined inside Reactant.jl
384427

428+
traced_call(f, args...; kwargs...) = f(args...; kwargs...)
429+
385430
function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
386431
return MacroTools.postwalk(expr) do x
387432
if Meta.isexpr(x, :kw) # undo lhs rewriting

src/Compiler.jl

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import ..Reactant:
2020
ancestor,
2121
TracedType
2222

23+
import ..ReactantCore: correct_maybe_bcast_call
24+
2325
@inline function traced_getfield(@nospecialize(obj), field)
2426
return Base.getfield(obj, field)
2527
end
@@ -440,18 +442,34 @@ const DEBUG_KERNEL = Ref{Bool}(false)
440442
const DUMP_LLVMIR = Ref{Bool}(false)
441443

442444
function compile_mlir!(
443-
mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu"
445+
mod,
446+
f,
447+
args,
448+
callcache=Dict{
449+
Vector,
450+
@NamedTuple{
451+
f_name::String,
452+
mlir_result_types::Vector{MLIR.IR.Type},
453+
traced_result::Any,
454+
mutated::Vector{Int},
455+
}
456+
}();
457+
optimize::Union{Bool,Symbol}=true,
458+
no_nan::Bool=false,
459+
backend="gpu",
444460
)
445461
# Explicitly don't use block! to avoid creating a closure, which creates
446462
# both compile-time and relocatability issues
447463

448464
MLIR.IR.activate!(mod)
449465
MLIR.IR.activate!(MLIR.IR.body(mod))
466+
activate_callcache!(callcache)
450467
fnwrapped,
451468
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
452469
linear_results = try
453470
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
454471
finally
472+
deactivate_callcache!(callcache)
455473
MLIR.IR.deactivate!(MLIR.IR.body(mod))
456474
MLIR.IR.deactivate!(mod)
457475
end
@@ -716,11 +734,6 @@ function compile_call_expr(mod, compiler, options, args...)
716734
(; compiled=compiled_symbol, args=args_symbol)
717735
end
718736

719-
function correct_maybe_bcast_call(fname)
720-
startswith(string(fname), '.') || return false, fname, fname
721-
return true, Symbol(string(fname)[2:end]), fname
722-
end
723-
724737
"""
725738
codegen_flatten!
726739
@@ -1167,4 +1180,40 @@ function register_thunk(
11671180
return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
11681181
end
11691182

1183+
function activate_callcache!(callcache)
1184+
stack = get!(task_local_storage(), :callcache) do
1185+
return []
1186+
end
1187+
push!(stack, callcache)
1188+
return nothing
1189+
end
1190+
1191+
function deactivate_callcache!(callcache)
1192+
callcache === last(task_local_storage(:callcache)) ||
1193+
error("Deactivating wrong callcache")
1194+
return pop!(task_local_storage(:callcache))
1195+
end
1196+
1197+
function _has_callcache()
1198+
return haskey(task_local_storage(), :callcache) &&
1199+
!Base.isempty(task_local_storage(:callcache))
1200+
end
1201+
1202+
function callcache(; throw_error::Bool=true)
1203+
if !_has_callcache()
1204+
throw_error && error("No callcache is active")
1205+
return nothing
1206+
end
1207+
return last(task_local_storage(:callcache))
1208+
end
1209+
1210+
function callcache!(f, callcache)
1211+
activate_callcache!(callcache)
1212+
try
1213+
return f()
1214+
finally
1215+
deactivate_callcache!(callcache)
1216+
end
1217+
end
1218+
11701219
end

src/ControlFlow.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ function ReactantCore.traced_if(
44
return Ops.if_condition(cond, true_fn, false_fn, args...)
55
end
66

7+
function ReactantCore.traced_call(f::Function, args...)
8+
return Ops.call(f, args...)
9+
end
10+
711
function ReactantCore.traced_while(cond_fn::CFn, body_fn::BFn, args) where {CFn,BFn}
812
return Ops.while_loop(cond_fn, body_fn, args...)
913
end

src/Interpreter.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,25 @@ function set_reactant_abi(
3939
)
4040
(; fargs, argtypes) = arginfo
4141

42+
if f === ReactantCore.within_compile
43+
if length(argtypes) != 1
44+
@static if VERSION < v"1.11.0-"
45+
return CallMeta(Union{}, Effects(), NoCallInfo())
46+
else
47+
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
48+
end
49+
end
50+
@static if VERSION < v"1.11.0-"
51+
return CallMeta(
52+
Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
53+
)
54+
else
55+
return CallMeta(
56+
Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
57+
)
58+
end
59+
end
60+
4261
# Improve inference by considering call_with_reactant as having the same results as
4362
# the original call
4463
if f === Reactant.call_with_reactant
@@ -236,7 +255,7 @@ function overload_autodiff(
236255
primf = f.val
237256
primargs = ((v.val for v in args)...,)
238257

239-
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
258+
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results, _ = TracedUtils.make_mlir_fn(
240259
primf, primargs, (), string(f) * "_autodiff", false
241260
)
242261

src/Ops.jl

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ end
10701070
(sample_inputs...,),
10711071
(),
10721072
"comparator";
1073-
no_args_in_result=true,
1073+
args_in_result=:none,
10741074
return_dialect=:stablehlo,
10751075
)[2]
10761076
@assert MLIR.IR.nregions(func) == 1
@@ -1679,7 +1679,7 @@ end
16791679
string(gensym("cond_fn")),
16801680
false;
16811681
return_dialect=:stablehlo,
1682-
no_args_in_result=true,
1682+
args_in_result=:none,
16831683
do_transpose=false,
16841684
)
16851685

@@ -1690,7 +1690,7 @@ end
16901690
string(gensym("body_fn")),
16911691
false;
16921692
return_dialect=:stablehlo,
1693-
no_args_in_result=true,
1693+
args_in_result=:none,
16941694
do_transpose=false,
16951695
)
16961696

@@ -2060,4 +2060,75 @@ end
20602060
return corrected_traced_results
20612061
end
20622062

2063+
@noinline function call(f, args...)
2064+
seen_cache = Reactant.OrderedIdDict()
2065+
Reactant.make_tracer(
2066+
seen_cache,
2067+
args,
2068+
(), # we have to insert something here, but we remove it immediately below.
2069+
Reactant.TracedTrack;
2070+
toscalar=false,
2071+
)
2072+
linear_args = []
2073+
mlir_caller_args = Reactant.MLIR.IR.Value[]
2074+
for (k, v) in seen_cache
2075+
v isa Reactant.TracedType || continue
2076+
push!(linear_args, v)
2077+
push!(mlir_caller_args, v.mlir_data)
2078+
# make tracer inserted `()` into the path, here we remove it:
2079+
v.paths = v.paths[1:(end - 1)]
2080+
end
2081+
2082+
seen = Dict()
2083+
cache_key = []
2084+
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
2085+
cache = Reactant.Compiler.callcache()
2086+
if haskey(cache, cache_key)
2087+
# cache lookup:
2088+
(; f_name, mlir_result_types, traced_result, mutated) = cache[cache_key]
2089+
else
2090+
f_name = String(gensym(Symbol(f)))
2091+
temp = Reactant.TracedUtils.make_mlir_fn(
2092+
f, args, (), f_name, false; args_in_result=:mutated, do_transpose=false
2093+
)
2094+
traced_result, ret, mutated = temp[[3, 6, 10]]
2095+
mlir_result_types = [
2096+
MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)
2097+
]
2098+
cache[cache_key] = (; f_name, mlir_result_types, traced_result, mutated)
2099+
end
2100+
2101+
call_op = MLIR.Dialects.func.call(
2102+
mlir_caller_args;
2103+
result_0=mlir_result_types,
2104+
callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
2105+
)
2106+
2107+
seen_results = Reactant.OrderedIdDict()
2108+
traced_result = Reactant.make_tracer(
2109+
seen_results,
2110+
traced_result,
2111+
(), # we have to insert something here, but we remove it immediately below.
2112+
Reactant.TracedSetPath;
2113+
toscalar=false,
2114+
)
2115+
i = 1
2116+
for (k, v) in seen_results
2117+
v isa Reactant.TracedType || continue
2118+
# this mutates `traced_result`, which is what we want:
2119+
v.mlir_data = MLIR.IR.result(call_op, i)
2120+
# make tracer inserted `()` into the path, here we remove it:
2121+
v.paths = v.paths[1:(end - 1)]
2122+
i += 1
2123+
end
2124+
nres = MLIR.IR.nresults(call_op)
2125+
# mutated args are included as the last ones in the call op results
2126+
for (result_i, arg_i) in zip((nres - length(mutated)):nres, mutated)
2127+
Reactant.TracedUtils.set_mlir_data!(
2128+
linear_args[arg_i], MLIR.IR.result(call_op, result_i + 1)
2129+
)
2130+
end
2131+
return traced_result
2132+
end
2133+
20632134
end # module Ops

src/Reactant.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Reactant
22

3-
using ReactantCore: ReactantCore, @trace, MissingTracedValue
3+
using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue
44

55
using LinearAlgebra: LinearAlgebra
66
using Random: Random, AbstractRNG
@@ -231,7 +231,7 @@ function Enzyme.make_zero(
231231
end
232232

233233
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
234-
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
234+
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, within_compile
235235

236236
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()
237237

0 commit comments

Comments
 (0)