@@ -222,6 +222,75 @@ function sample(
222
222
return result
223
223
end
224
224
225
+ function call (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
226
+ res = @jit optimize = :probprog call_internal (f, args... )
227
+ return res isa AbstractConcreteArray ? Array (res) : res
228
+ end
229
+
230
+ function call_internal (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
231
+ argprefix:: Symbol = gensym (" callarg" )
232
+ resprefix:: Symbol = gensym (" callresult" )
233
+ resargprefix:: Symbol = gensym (" callresarg" )
234
+
235
+ mlir_fn_res = invokelatest (
236
+ TracedUtils. make_mlir_fn,
237
+ f,
238
+ args,
239
+ (),
240
+ string (f),
241
+ false ;
242
+ do_transpose= false ,
243
+ args_in_result= :all ,
244
+ argprefix,
245
+ resprefix,
246
+ resargprefix,
247
+ )
248
+ (; result, linear_args, in_tys, linear_results) = mlir_fn_res
249
+ fnwrap = mlir_fn_res. fnwrapped
250
+ func2 = mlir_fn_res. f
251
+
252
+ out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
253
+ fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
254
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
255
+
256
+ batch_inputs = MLIR. IR. Value[]
257
+ for a in linear_args
258
+ idx, path = TracedUtils. get_argidx (a, argprefix)
259
+ if idx == 1 && fnwrap
260
+ TracedUtils. push_val! (batch_inputs, f, path[3 : end ])
261
+ else
262
+ if fnwrap
263
+ idx -= 1
264
+ end
265
+ TracedUtils. push_val! (batch_inputs, args[idx], path[3 : end ])
266
+ end
267
+ end
268
+
269
+ call_op = MLIR. Dialects. enzyme. untracedCall (batch_inputs; outputs= out_tys, fn= fname)
270
+
271
+ for (i, res) in enumerate (linear_results)
272
+ resv = MLIR. IR. result (call_op, i)
273
+ if TracedUtils. has_idx (res, resprefix)
274
+ path = TracedUtils. get_idx (res, resprefix)
275
+ TracedUtils. set! (result, path[2 : end ], resv)
276
+ elseif TracedUtils. has_idx (res, argprefix)
277
+ idx, path = TracedUtils. get_argidx (res, argprefix)
278
+ if idx == 1 && fnwrap
279
+ TracedUtils. set! (f, path[3 : end ], resv)
280
+ else
281
+ if fnwrap
282
+ idx -= 1
283
+ end
284
+ TracedUtils. set! (args[idx], path[3 : end ], resv)
285
+ end
286
+ else
287
+ TracedUtils. set! (res, (), resv)
288
+ end
289
+ end
290
+
291
+ return result
292
+ end
293
+
225
294
function generate (f:: Function , args:: Vararg{Any,Nargs} ; constraints= nothing ) where {Nargs}
226
295
trace = ProbProgTrace ()
227
296
0 commit comments