@@ -146,7 +146,8 @@ function sample(
146
146
in_idx = nothing
147
147
for (i, arg) in enumerate (linear_args)
148
148
if TracedUtils. has_idx (arg, argprefix) &&
149
- TracedUtils. get_idx (arg, argprefix) == TracedUtils. get_idx (res, argprefix)
149
+ TracedUtils. get_idx (arg, argprefix) ==
150
+ TracedUtils. get_idx (res, argprefix)
150
151
in_idx = i - 1
151
152
break
152
153
end
@@ -221,10 +222,12 @@ function sample(
221
222
return result
222
223
end
223
224
224
- function generate (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
225
+ function generate (f:: Function , args:: Vararg{Any,Nargs} ; constraints = nothing ) where {Nargs}
225
226
trace = ProbProgTrace ()
226
227
227
- weight, res = @jit optimize = :probprog generate_internal (f, args... ; trace)
228
+ weight, res = @jit sync = true optimize = :probprog generate_internal (
229
+ f, args... ; trace, constraints
230
+ )
228
231
229
232
trace. retval = res isa AbstractConcreteArray ? Array (res) : res
230
233
trace. weight = Array (weight)[1 ]
@@ -233,7 +236,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
233
236
end
234
237
235
238
function generate_internal (
236
- f:: Function , args:: Vararg{Any,Nargs} ; trace:: ProbProgTrace
239
+ f:: Function , args:: Vararg{Any,Nargs} ; trace:: ProbProgTrace , constraints = nothing
237
240
) where {Nargs}
238
241
argprefix:: Symbol = gensym (" generatearg" )
239
242
resprefix:: Symbol = gensym (" generateresult" )
@@ -276,9 +279,46 @@ function generate_internal(
276
279
277
280
trace_addr = reinterpret (UInt64, pointer_from_objref (trace))
278
281
279
- # Output: (weight, f's outputs...)
282
+ constraints_attr = nothing
283
+ if constraints != = nothing && ! isempty (constraints)
284
+ constraint_attrs = MLIR. IR. Attribute[]
285
+
286
+ for (sym, constraint) in constraints
287
+ sym_addr = reinterpret (UInt64, pointer_from_objref (sym))
288
+
289
+ if ! (constraint isa AbstractArray)
290
+ error (
291
+ " Constraints must be an array (one element per traced output) of arrays"
292
+ )
293
+ end
294
+
295
+ sym_constraint_attrs = MLIR. IR. Attribute[]
296
+ for oc in constraint
297
+ if ! (oc isa AbstractArray)
298
+ error (" Per-output constraints must be arrays" )
299
+ end
300
+
301
+ push! (sym_constraint_attrs, MLIR. IR. DenseElementsAttribute (oc))
302
+ end
303
+
304
+ cattr_ptr = @ccall MLIR. API. mlir_c. enzymeConstraintAttrGet (
305
+ MLIR. IR. context ():: MLIR.API.MlirContext ,
306
+ sym_addr:: UInt64 ,
307
+ MLIR. IR. Attribute (sym_constraint_attrs):: MLIR.API.MlirAttribute ,
308
+ ):: MLIR.API.MlirAttribute
309
+
310
+ push! (constraint_attrs, MLIR. IR. Attribute (cattr_ptr))
311
+ end
312
+
313
+ constraints_attr = MLIR. IR. Attribute (constraint_attrs)
314
+ end
315
+
280
316
gen_op = MLIR. Dialects. enzyme. generate (
281
- batch_inputs; outputs= out_tys, fn= fname, trace= trace_addr
317
+ batch_inputs;
318
+ outputs= out_tys,
319
+ fn= fname,
320
+ trace= trace_addr,
321
+ constraints= constraints_attr,
282
322
)
283
323
284
324
weight = TracedRArray (MLIR. IR. result (gen_op, 1 ))
0 commit comments