Skip to content

Commit 1ad167a

Browse files
committed
untraced call
1 parent 5b5c1d1 commit 1ad167a

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

src/ProbProg.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,75 @@ function sample(
222222
return result
223223
end
224224

225+
function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
226+
res = @jit optimize = :probprog call_internal(f, args...)
227+
return res isa AbstractConcreteArray ? Array(res) : res
228+
end
229+
230+
function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
231+
argprefix::Symbol = gensym("callarg")
232+
resprefix::Symbol = gensym("callresult")
233+
resargprefix::Symbol = gensym("callresarg")
234+
235+
mlir_fn_res = invokelatest(
236+
TracedUtils.make_mlir_fn,
237+
f,
238+
args,
239+
(),
240+
string(f),
241+
false;
242+
do_transpose=false,
243+
args_in_result=:all,
244+
argprefix,
245+
resprefix,
246+
resargprefix,
247+
)
248+
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
249+
fnwrap = mlir_fn_res.fnwrapped
250+
func2 = mlir_fn_res.f
251+
252+
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
253+
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
254+
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
255+
256+
batch_inputs = MLIR.IR.Value[]
257+
for a in linear_args
258+
idx, path = TracedUtils.get_argidx(a, argprefix)
259+
if idx == 1 && fnwrap
260+
TracedUtils.push_val!(batch_inputs, f, path[3:end])
261+
else
262+
if fnwrap
263+
idx -= 1
264+
end
265+
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
266+
end
267+
end
268+
269+
call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname)
270+
271+
for (i, res) in enumerate(linear_results)
272+
resv = MLIR.IR.result(call_op, i)
273+
if TracedUtils.has_idx(res, resprefix)
274+
path = TracedUtils.get_idx(res, resprefix)
275+
TracedUtils.set!(result, path[2:end], resv)
276+
elseif TracedUtils.has_idx(res, argprefix)
277+
idx, path = TracedUtils.get_argidx(res, argprefix)
278+
if idx == 1 && fnwrap
279+
TracedUtils.set!(f, path[3:end], resv)
280+
else
281+
if fnwrap
282+
idx -= 1
283+
end
284+
TracedUtils.set!(args[idx], path[3:end], resv)
285+
end
286+
else
287+
TracedUtils.set!(res, (), resv)
288+
end
289+
end
290+
291+
return result
292+
end
293+
225294
function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs}
226295
trace = ProbProgTrace()
227296

test/probprog/sample.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ end
2424
seed = Reactant.to_rarray(UInt64[1, 4])
2525
μ = Reactant.ConcreteRNumber(0.0)
2626
σ = Reactant.ConcreteRNumber(1.0)
27-
before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape)
27+
before = @code_hlo optimize = false ProbProg.call_internal(one_sample, seed, μ, σ, shape)
2828
@test contains(repr(before), "enzyme.sample")
29-
after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape)
29+
after = @code_hlo optimize = :probprog ProbProg.call_internal(two_samples, seed, μ, σ, shape)
3030
@test !contains(repr(after), "enzyme.sample")
3131
end
3232

@@ -35,8 +35,8 @@ end
3535
seed = Reactant.to_rarray(UInt64[1, 4])
3636
μ = Reactant.ConcreteRNumber(0.0)
3737
σ = Reactant.ConcreteRNumber(1.0)
38-
X = ProbProg.generate(one_sample, seed, μ, σ, shape)
39-
Y = ProbProg.generate(two_samples, seed, μ, σ, shape)
38+
X = ProbProg.call(one_sample, seed, μ, σ, shape)
39+
Y = ProbProg.call(two_samples, seed, μ, σ, shape)
4040
@test !all(X .≈ Y)
4141
end
4242
end

0 commit comments

Comments
 (0)