Skip to content

Commit f771bcb

Browse files
committed
update legacy inference API
1 parent ebec467 commit f771bcb

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

src/probprog/Inference.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@ function metropolis_hastings(
1010
error("MH requires a trace with fn and rng recorded (use generate to create trace)")
1111
end
1212

13-
constraints = Dict{Symbol,Any}()
14-
constrained_symbols = Set{Symbol}()
15-
13+
constraint_pairs = Pair{Symbol,Any}[]
1614
for (sym, val) in trace.choices
1715
if !(sym in sel)
18-
constraints[sym] = val
19-
push!(constrained_symbols, sym)
16+
push!(constraint_pairs, sym => val)
2017
end
2118
end
19+
constraint = Constraint(constraint_pairs...)
20+
21+
constrained_addresses = extract_addresses(constraint)
2222

23-
cache_key = (typeof(trace.fn), constrained_symbols)
23+
cache_key = (typeof(trace.fn), constrained_addresses)
2424

2525
compiled_fn = nothing
2626
if compiled_cache !== nothing
@@ -30,12 +30,12 @@ function metropolis_hastings(
3030
if compiled_fn === nothing
3131
function wrapper_fn(rng, constraint_ptr, args...)
3232
return generate_internal(
33-
rng, trace.fn, args...; constraint_ptr, constrained_symbols
33+
rng, trace.fn, args...; constraint_ptr, constrained_addresses
3434
)
3535
end
3636

3737
constraint_ptr = ConcreteRNumber(
38-
reinterpret(UInt64, pointer_from_objref(constraints))
38+
reinterpret(UInt64, pointer_from_objref(constraint))
3939
)
4040

4141
compiled_fn = @compile optimize = :probprog wrapper_fn(
@@ -47,7 +47,7 @@ function metropolis_hastings(
4747
end
4848
end
4949

50-
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints)))
50+
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
5151

5252
old_gc_state = GC.enable(false)
5353
new_trace_ptr = nothing

src/probprog/Types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default)
7070
extract_addresses(constraint::Constraint) = Set(keys(constraint))
7171

7272
const Selection = Set{Symbol}
73-
const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
73+
const CompiledFnCache = Dict{Tuple{Type,Set{Address}},Any}
7474

7575
const _probprog_ref_lock = ReentrantLock()
7676
const _probprog_refs = IdDict()

test/probprog/linear_regression.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,12 @@ end
3333
function my_inference_program(xs, ys, num_iters)
3434
xs_r = Reactant.to_rarray(xs)
3535

36-
constraint = ProbProg.choicemap()
37-
constraint[:ys] = [ys]
36+
observations = ProbProg.Constraint(:ys => (ys,))
3837

3938
seed = Reactant.to_rarray(UInt64[1, 4])
4039
rng = ReactantRNG(seed)
4140

42-
trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint)
41+
trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint=observations)
4342

4443
trace = ProbProg.with_compiled_cache() do cache
4544
local t = trace

0 commit comments

Comments
 (0)