Skip to content

Commit 5b5c1d1

Browse files
committed
generate op with constraints
1 parent 6e4dc0c commit 5b5c1d1

File tree

3 files changed

+79
-6
lines changed

3 files changed

+79
-6
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,20 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) {
353353
(mlir::enzyme::Activity)val));
354354
}
355355

356+
extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet(
357+
MlirContext ctx, uint64_t symbol, MlirAttribute values) {
358+
mlir::Attribute vals = unwrap(values);
359+
auto arr = llvm::dyn_cast<mlir::ArrayAttr>(vals);
360+
if (!arr) {
361+
ReactantThrowError(
362+
"enzymeConstraintAttrGet: `values` must be an ArrayAttr");
363+
return MlirAttribute{nullptr};
364+
}
365+
mlir::Attribute attr =
366+
mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr);
367+
return wrap(attr);
368+
}
369+
356370
// Create profiler session and start profiling
357371
extern "C" tsl::ProfilerSession *
358372
CreateProfilerSession(uint32_t device_tracer_level,

src/ProbProg.jl

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ function sample(
146146
in_idx = nothing
147147
for (i, arg) in enumerate(linear_args)
148148
if TracedUtils.has_idx(arg, argprefix) &&
149-
TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix)
149+
TracedUtils.get_idx(arg, argprefix) ==
150+
TracedUtils.get_idx(res, argprefix)
150151
in_idx = i - 1
151152
break
152153
end
@@ -221,10 +222,12 @@ function sample(
221222
return result
222223
end
223224

224-
function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
225+
function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs}
225226
trace = ProbProgTrace()
226227

227-
weight, res = @jit optimize = :probprog generate_internal(f, args...; trace)
228+
weight, res = @jit sync = true optimize = :probprog generate_internal(
229+
f, args...; trace, constraints
230+
)
228231

229232
trace.retval = res isa AbstractConcreteArray ? Array(res) : res
230233
trace.weight = Array(weight)[1]
@@ -233,7 +236,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
233236
end
234237

235238
function generate_internal(
236-
f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace
239+
f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace, constraints=nothing
237240
) where {Nargs}
238241
argprefix::Symbol = gensym("generatearg")
239242
resprefix::Symbol = gensym("generateresult")
@@ -276,9 +279,46 @@ function generate_internal(
276279

277280
trace_addr = reinterpret(UInt64, pointer_from_objref(trace))
278281

279-
# Output: (weight, f's outputs...)
282+
constraints_attr = nothing
283+
if constraints !== nothing && !isempty(constraints)
284+
constraint_attrs = MLIR.IR.Attribute[]
285+
286+
for (sym, constraint) in constraints
287+
sym_addr = reinterpret(UInt64, pointer_from_objref(sym))
288+
289+
if !(constraint isa AbstractArray)
290+
error(
291+
"Constraints must be an array (one element per traced output) of arrays"
292+
)
293+
end
294+
295+
sym_constraint_attrs = MLIR.IR.Attribute[]
296+
for oc in constraint
297+
if !(oc isa AbstractArray)
298+
error("Per-output constraints must be arrays")
299+
end
300+
301+
push!(sym_constraint_attrs, MLIR.IR.DenseElementsAttribute(oc))
302+
end
303+
304+
cattr_ptr = @ccall MLIR.API.mlir_c.enzymeConstraintAttrGet(
305+
MLIR.IR.context()::MLIR.API.MlirContext,
306+
sym_addr::UInt64,
307+
MLIR.IR.Attribute(sym_constraint_attrs)::MLIR.API.MlirAttribute,
308+
)::MLIR.API.MlirAttribute
309+
310+
push!(constraint_attrs, MLIR.IR.Attribute(cattr_ptr))
311+
end
312+
313+
constraints_attr = MLIR.IR.Attribute(constraint_attrs)
314+
end
315+
280316
gen_op = MLIR.Dialects.enzyme.generate(
281-
batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr
317+
batch_inputs;
318+
outputs=out_tys,
319+
fn=fname,
320+
trace=trace_addr,
321+
constraints=constraints_attr,
282322
)
283323

284324
weight = TracedRArray(MLIR.IR.result(gen_op, 1))

test/probprog/generate.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,23 @@ end
4040
trace, weight = ProbProg.generate(model, seed, μ, σ, shape)
4141
@test mean(trace.retval) 0.0 atol = 0.05 rtol = 0.05
4242
end
43+
44+
@testset "constraints" begin
45+
shape = (10,)
46+
seed = Reactant.to_rarray(UInt64[1, 4])
47+
μ = Reactant.ConcreteRNumber(0.0)
48+
σ = Reactant.ConcreteRNumber(1.0)
49+
50+
s_constraint = fill(0.1, shape)
51+
constraints = Dict(:s => [s_constraint])
52+
53+
trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraints)
54+
55+
@test trace.choices[:s] == s_constraint
56+
57+
expected_weight =
58+
normal_logpdf(s_constraint, 0.0, 1.0, shape) +
59+
normal_logpdf(trace.choices[:t], s_constraint, 1.0, shape)
60+
@test weight expected_weight atol = 1e-6
61+
end
4362
end

0 commit comments

Comments
 (0)