Skip to content

Commit 87ced72

Browse files
committed
generate op fixup: replacing constrained_symbols with constrained_addresses. Address struct are now the key of Constraint dicts
1 parent a29fbed commit 87ced72

File tree

5 files changed

+183
-32
lines changed

5 files changed

+183
-32
lines changed

src/probprog/FFI.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,14 @@ function getSampleFromConstraint(
157157
shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples)
158158
sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples)
159159

160-
tostore = get(constraint, symbol, nothing)
160+
tostore = get(constraint, Address(symbol), nothing)
161+
162+
if tostore === nothing
163+
@ccall printf(
164+
"No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring
165+
)::Cvoid
166+
return nothing
167+
end
161168

162169
for i in 1:num_samples
163170
ndims = ndims_array[i]
@@ -228,6 +235,37 @@ function getSampleFromConstraint(
228235
return nothing
229236
end
230237

238+
function getSubconstraint(
239+
constraint_ptr_ptr::Ptr{Ptr{Any}},
240+
symbol_ptr_ptr::Ptr{Ptr{Any}},
241+
subconstraint_ptr_ptr::Ptr{Ptr{Any}},
242+
)
243+
constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint
244+
symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol
245+
246+
subconstraint = Constraint()
247+
248+
for (key, value) in constraint
249+
if key.path[1] == symbol
250+
@assert isa(key, Address) "Expected Address type for constraint key"
251+
@assert length(key.path) > 1 "Expected composite address with length > 1"
252+
tail_address = Address(key.path[2:end])
253+
subconstraint[tail_address] = value
254+
end
255+
end
256+
257+
if isempty(subconstraint)
258+
@ccall printf(
259+
"No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring
260+
)::Cvoid
261+
return nothing
262+
end
263+
264+
_keepalive!(subconstraint)
265+
unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint))
266+
return nothing
267+
end
268+
231269
function __init__()
232270
init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},))
233271
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
@@ -297,5 +335,12 @@ function __init__()
297335
get_sample_from_constraint_ptr::Ptr{Cvoid},
298336
)::Cvoid
299337

338+
get_subconstraint_ptr = @cfunction(
339+
getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}})
340+
)
341+
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
342+
:enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid}
343+
)::Cvoid
344+
300345
return nothing
301346
end

src/probprog/Modeling.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,16 @@ function generate(
349349
rng::AbstractRNG,
350350
f::Function,
351351
args::Vararg{Any,Nargs};
352-
constraint::Constraint=Dict{Symbol,Any}(),
352+
constraint::Constraint=Constraint(),
353353
) where {Nargs}
354354
trace = nothing
355355

356356
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
357-
constrained_symbols = Set(keys(constraint))
357+
358+
constrained_addresses = _extract_addresses(constraint)
358359

359360
function wrapper_fn(rng, constraint_ptr, args...)
360-
return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols)
361+
return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses)
361362
end
362363

363364
compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...)
@@ -378,12 +379,18 @@ function generate(
378379
return trace, trace.weight
379380
end
380381

382+
function _extract_addresses(constraint::Constraint)
383+
addresses = Set(keys(constraint))
384+
385+
return addresses
386+
end
387+
381388
function generate_internal(
382389
rng::AbstractRNG,
383390
f::Function,
384391
args::Vararg{Any,Nargs};
385392
constraint_ptr::TracedRNumber,
386-
constrained_symbols::Set{Symbol},
393+
constrained_addresses::Set{Address},
387394
) where {Nargs}
388395
argprefix::Symbol = gensym("generatearg")
389396
resprefix::Symbol = gensym("generateresult")
@@ -441,15 +448,19 @@ function generate_internal(
441448
1,
442449
)
443450

444-
constrained_symbols_attr = MLIR.IR.Attribute[]
445-
for sym in constrained_symbols
446-
addr = reinterpret(UInt64, pointer_from_objref(sym))
447-
push!(
448-
constrained_symbols_attr,
449-
@ccall MLIR.API.mlir_c.enzymeSymbolAttrGet(
450-
MLIR.IR.context()::MLIR.API.MlirContext, addr::UInt64
451-
)::MLIR.IR.Attribute
452-
)
451+
constrained_addresses_attr = MLIR.IR.Attribute[]
452+
for address in constrained_addresses
453+
address_attr = MLIR.IR.Attribute[]
454+
for sym in address.path
455+
sym_addr = reinterpret(UInt64, pointer_from_objref(sym))
456+
push!(
457+
address_attr,
458+
@ccall MLIR.API.mlir_c.enzymeSymbolAttrGet(
459+
MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64
460+
)::MLIR.IR.Attribute
461+
)
462+
end
463+
push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr))
453464
end
454465

455466
trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet(
@@ -464,7 +475,7 @@ function generate_internal(
464475
weight=weight_ty,
465476
outputs=out_tys,
466477
fn=fn_attr,
467-
constrained_symbols=MLIR.IR.Attribute(constrained_symbols_attr),
478+
constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr),
468479
)
469480

470481
for (i, res) in enumerate(linear_results)

src/probprog/ProbProg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("Modeling.jl")
1919
include("Inference.jl")
2020
include("Display.jl")
2121

22-
export ProbProgTrace, Constraint, Selection, CompiledFnCache
22+
export ProbProgTrace, Constraint, Selection, CompiledFnCache, Address
2323
export get_choices, select, choicemap, with_compiled_cache
2424

2525
export sample, call, simulate, generate

src/probprog/Types.jl

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,68 @@ mutable struct ProbProgTrace
2222
end
2323
end
2424

25-
const Constraint = Dict{Symbol,Any}
25+
struct Address
26+
path::Vector{Symbol}
27+
28+
Address(path::Vector{Symbol}) = new(path)
29+
end
30+
Address(sym::Symbol) = Address([sym])
31+
Address(syms::Symbol...) = Address([syms...])
32+
33+
Base.:(==)(a::Address, b::Address) = a.path == b.path
34+
Base.hash(a::Address, h::UInt) = hash(a.path, h)
35+
36+
mutable struct Constraint <: AbstractDict{Address,Any}
37+
dict::Dict{Address,Any}
38+
39+
function Constraint(pairs::Pair...)
40+
dict = Dict{Address,Any}()
41+
for pair in pairs
42+
symbols = Symbol[]
43+
current = pair
44+
while isa(current, Pair) && isa(current.first, Symbol)
45+
push!(symbols, current.first)
46+
current = current.second
47+
end
48+
dict[Address(symbols...)] = current
49+
end
50+
return new(dict)
51+
end
52+
53+
Constraint() = new(Dict{Address,Any}())
54+
Constraint(d::Dict{Address,Any}) = new(d)
55+
end
56+
57+
Base.getindex(c::Constraint, k::Address) = c.dict[k]
58+
Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v)
59+
Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k)
60+
Base.keys(c::Constraint) = keys(c.dict)
61+
Base.values(c::Constraint) = values(c.dict)
62+
Base.iterate(c::Constraint) = iterate(c.dict)
63+
Base.iterate(c::Constraint, state) = iterate(c.dict, state)
64+
Base.length(c::Constraint) = length(c.dict)
65+
Base.isempty(c::Constraint) = isempty(c.dict)
66+
Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k)
67+
Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default)
68+
2669
const Selection = Set{Symbol}
2770
const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
2871

29-
const _trace_ref_lock = ReentrantLock()
30-
const _trace_refs = Vector{Any}()
72+
const _probprog_ref_lock = ReentrantLock()
73+
const _probprog_refs = IdDict()
3174

32-
function _keepalive!(tr::ProbProgTrace)
33-
lock(_trace_ref_lock)
75+
function _keepalive!(tr::Any)
76+
lock(_probprog_ref_lock)
3477
try
35-
push!(_trace_refs, tr)
78+
_probprog_refs[tr] = tr
3679
finally
37-
unlock(_trace_ref_lock)
80+
unlock(_probprog_ref_lock)
3881
end
3982
return tr
4083
end
4184

4285
get_choices(trace::ProbProgTrace) = trace.choices
4386
select(syms::Symbol...) = Set(syms)
44-
choicemap() = Constraint()
4587

4688
function with_compiled_cache(f)
4789
cache = CompiledFnCache()

test/probprog/generate.jl

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,19 @@ function model(rng, μ, σ, shape)
1313
return t
1414
end
1515

16+
function two_normals(rng, μ, σ, shape)
17+
x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf)
18+
y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf)
19+
return y
20+
end
21+
22+
function nested_model(rng, μ, σ, shape)
23+
s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
24+
t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t)
25+
u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u)
26+
return u
27+
end
28+
1629
@testset "Generate" begin
1730
@testset "unconstrained" begin
1831
shape = (1000,)
@@ -31,15 +44,55 @@ end
3144
μ = Reactant.ConcreteRNumber(0.0)
3245
σ = Reactant.ConcreteRNumber(1.0)
3346

34-
constraint = Dict{Symbol,Any}(:s => (fill(0.1, shape),))
47+
constraint = ProbProg.Constraint(:s => (fill(0.1, shape),))
3548

3649
trace, weight = ProbProg.generate(rng, model, μ, σ, shape; constraint)
3750

38-
@test trace.choices[:s][1] == constraint[:s][1]
51+
@test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1]
3952

4053
expected_weight =
41-
normal_logpdf(constraint[:s][1], 0.0, 1.0, shape) +
42-
normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape)
54+
normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) +
55+
normal_logpdf(
56+
trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape
57+
)
58+
@test weight expected_weight atol = 1e-6
59+
end
60+
61+
@testset "composite addresses" begin
62+
shape = (10,)
63+
seed = Reactant.to_rarray(UInt64[1, 4])
64+
rng = ReactantRNG(seed)
65+
μ = Reactant.ConcreteRNumber(0.0)
66+
σ = Reactant.ConcreteRNumber(1.0)
67+
68+
constraint = ProbProg.Constraint(
69+
:s => (fill(0.1, shape),),
70+
:t => :x => (fill(0.2, shape),),
71+
:u => :y => (fill(0.3, shape),),
72+
)
73+
74+
trace, weight = ProbProg.generate(rng, nested_model, μ, σ, shape; constraint)
75+
76+
@test trace.choices[:s][1] == fill(0.1, shape)
77+
@test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape)
78+
@test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape)
79+
80+
s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape)
81+
tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape)
82+
ty_weight = normal_logpdf(
83+
trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape
84+
)
85+
ux_weight = normal_logpdf(
86+
trace.subtraces[:u].choices[:x][1],
87+
trace.subtraces[:t].choices[:y][1],
88+
1.0,
89+
shape,
90+
)
91+
uy_weight = normal_logpdf(
92+
fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape
93+
)
94+
95+
expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight
4396
@test weight expected_weight atol = 1e-6
4497
end
4598

@@ -50,24 +103,24 @@ end
50103
μ = Reactant.ConcreteRNumber(0.0)
51104
σ = Reactant.ConcreteRNumber(1.0)
52105

53-
constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),))
106+
constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),))
54107

55-
constrained_symbols = Set(keys(constraint1))
108+
constrained_addresses = ProbProg._extract_addresses(constraint1)
56109

57110
constraint_ptr1 = Reactant.ConcreteRNumber(
58111
reinterpret(UInt64, pointer_from_objref(constraint1))
59112
)
60113

61114
wrapper_fn(constraint_ptr, rng, μ, σ) = ProbProg.generate_internal(
62-
rng, model, μ, σ, shape; constraint_ptr, constrained_symbols
115+
rng, model, μ, σ, shape; constraint_ptr, constrained_addresses
63116
)
64117

65118
compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, rng, μ, σ)
66119

67120
trace1, weight = compiled_fn(constraint_ptr1, rng, μ, σ)
68121
trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1]))
69122

70-
constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),))
123+
constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),))
71124
constraint_ptr2 = Reactant.ConcreteRNumber(
72125
reinterpret(UInt64, pointer_from_objref(constraint2))
73126
)

0 commit comments

Comments
 (0)