Skip to content

Commit 2b81db9

Browse files
committed
refactored mh inference steps with new calling convention enforced
1 parent c57a1e4 commit 2b81db9

File tree

2 files changed

+121
-26
lines changed

2 files changed

+121
-26
lines changed

src/ProbProg.jl

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using ..Reactant:
1313
using ..Compiler: @jit, @compile
1414
using Enzyme
1515
using Base: ReentrantLock
16+
using Random
1617

1718
mutable struct ProbProgTrace
1819
fn::Union{Nothing,Function}
@@ -21,13 +22,18 @@ mutable struct ProbProgTrace
2122
retval::Any
2223
weight::Any
2324
subtraces::Dict{Symbol,Any}
25+
rng::Union{Nothing,AbstractRNG}
2426

2527
function ProbProgTrace(fn::Function, args::Tuple)
26-
return new(fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}())
28+
return new(
29+
fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing
30+
)
2731
end
2832

2933
function ProbProgTrace()
30-
return new(nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}())
34+
return new(
35+
nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing
36+
)
3137
end
3238
end
3339

@@ -587,6 +593,10 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
587593

588594
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
589595

596+
trace.fn = f
597+
trace.args = args
598+
trace.rng = rng
599+
590600
return trace, trace.weight
591601
end
592602

@@ -702,7 +712,7 @@ function generate(
702712
trace = nothing
703713

704714
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
705-
constrained_symbols = collect(keys(constraint))
715+
constrained_symbols = Set(keys(constraint))
706716

707717
function wrapper_fn(rng, constraint_ptr, args...)
708718
return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols)
@@ -719,6 +729,10 @@ function generate(
719729

720730
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
721731

732+
trace.fn = f
733+
trace.args = args
734+
trace.rng = rng
735+
722736
return trace, trace.weight
723737
end
724738

@@ -727,7 +741,7 @@ function generate_internal(
727741
f::Function,
728742
args::Vararg{Any,Nargs};
729743
constraint_ptr::TracedRNumber,
730-
constrained_symbols::Vector{Symbol},
744+
constrained_symbols::Set{Symbol},
731745
) where {Nargs}
732746
argprefix::Symbol = gensym("generatearg")
733747
resprefix::Symbol = gensym("generateresult")
@@ -947,4 +961,80 @@ end
947961

948962
get_choices(trace::ProbProgTrace) = trace.choices
949963

964+
const Selection = Set{Symbol}
965+
select(syms::Symbol...) = Set(syms)
966+
choicemap() = Constraint()
967+
const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
968+
969+
function metropolis_hastings(
970+
trace::ProbProgTrace,
971+
sel::Selection;
972+
compiled_cache::Union{Nothing,CompiledFnCache}=nothing,
973+
)
974+
if trace.fn === nothing || trace.rng === nothing
975+
error("MH requires a trace with fn and rng recorded (use generate to create trace)")
976+
end
977+
978+
constraints = Dict{Symbol,Any}()
979+
constrained_symbols = Set{Symbol}()
980+
981+
for (sym, val) in trace.choices
982+
if !(sym in sel)
983+
constraints[sym] = val
984+
push!(constrained_symbols, sym)
985+
end
986+
end
987+
988+
cache_key = (typeof(trace.fn), constrained_symbols)
989+
990+
compiled_fn = nothing
991+
if compiled_cache !== nothing
992+
compiled_fn = get(compiled_cache, cache_key, nothing)
993+
end
994+
995+
if compiled_fn === nothing
996+
function wrapper_fn(rng, constraint_ptr, args...)
997+
return generate_internal(
998+
rng, trace.fn, args...; constraint_ptr, constrained_symbols
999+
)
1000+
end
1001+
1002+
constraint_ptr = ConcreteRNumber(
1003+
reinterpret(UInt64, pointer_from_objref(constraints))
1004+
)
1005+
1006+
compiled_fn = @compile optimize = :probprog wrapper_fn(
1007+
trace.rng, constraint_ptr, trace.args...
1008+
)
1009+
1010+
if compiled_cache !== nothing
1011+
compiled_cache[cache_key] = compiled_fn
1012+
end
1013+
end
1014+
1015+
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints)))
1016+
1017+
old_gc_state = GC.enable(false)
1018+
new_trace_ptr = nothing
1019+
try
1020+
new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...)
1021+
finally
1022+
GC.enable(old_gc_state)
1023+
end
1024+
1025+
new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1]))
1026+
1027+
new_trace.fn = trace.fn
1028+
new_trace.args = trace.args
1029+
new_trace.rng = trace.rng
1030+
1031+
log_alpha = new_trace.weight - trace.weight
1032+
1033+
if log(rand()) < log_alpha
1034+
return (new_trace, true)
1035+
else
1036+
return (trace, false)
1037+
end
1038+
end
1039+
9501040
end

test/probprog/linear_regression.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,83 @@
11
using Reactant, Test, Random
2-
using Reactant: ProbProg
2+
using Reactant: ProbProg, ReactantRNG
33

44
# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/
55

66
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
77
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
88

9-
function my_model(seed, xs)
10-
rng = Random.default_rng()
11-
Random.seed!(rng, seed)
12-
9+
function my_model(rng, xs)
1310
slope = ProbProg.sample(
14-
normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf
11+
rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf
1512
)
1613
intercept = ProbProg.sample(
17-
normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf
14+
rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf
1815
)
1916

2017
ys = ProbProg.sample(
21-
normal,
2218
rng,
19+
normal,
2320
slope .* xs .+ intercept,
2421
1.0,
2522
(length(xs),);
2623
symbol=:ys,
2724
logpdf=normal_logpdf,
2825
)
2926

30-
return rng.seed, ys
27+
return ys
3128
end
3229

3330
function my_inference_program(xs, ys, num_iters)
3431
xs_r = Reactant.to_rarray(xs)
3532

36-
constraints = ProbProg.choicemap()
37-
constraints[:ys] = [ys]
33+
constraint = ProbProg.choicemap()
34+
constraint[:ys] = [ys]
3835

3936
seed = Reactant.to_rarray(UInt64[1, 4])
37+
rng = ReactantRNG(seed)
4038

41-
trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints)
42-
trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack
39+
trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint)
40+
41+
compiled_cache = ProbProg.CompiledFnCache()
4342

4443
for i in 1:num_iters
45-
trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope))
46-
trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept))
47-
choices = ProbProg.get_choices(trace)
48-
# @show i, choices[:slope], choices[:intercept]
44+
trace, _ = ProbProg.metropolis_hastings(
45+
trace, ProbProg.select(:slope); compiled_cache
46+
)
47+
trace, _ = ProbProg.metropolis_hastings(
48+
trace, ProbProg.select(:intercept); compiled_cache
49+
)
4950
end
5051

5152
choices = ProbProg.get_choices(trace)
52-
return (choices[:slope], choices[:intercept])
53+
return (Array(choices[:slope][1])[1], Array(choices[:intercept][1])[1])
5354
end
5455

5556
@testset "linear_regression" begin
5657
@testset "simulate" begin
5758
seed = Reactant.to_rarray(UInt64[1, 4])
58-
Random.seed!(42) # For Julia side RNG
59+
rng = ReactantRNG(seed)
5960

6061
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
6162
xs_r = Reactant.to_rarray(xs)
6263

63-
trace = ProbProg.simulate(my_model, seed, xs_r)
64+
trace, _ = ProbProg.simulate(rng, my_model, xs_r)
6465

6566
@test haskey(trace.choices, :slope)
6667
@test haskey(trace.choices, :intercept)
6768
@test haskey(trace.choices, :ys)
6869
end
6970

7071
@testset "inference" begin
72+
Random.seed!(1) # For Julia side RNG
7173
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
7274
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
7375

74-
slope, intercept = my_inference_program(xs, ys, 5)
76+
slope, intercept = my_inference_program(xs, ys, 10000)
77+
78+
@show slope, intercept
7579

76-
# @show slope, intercept
80+
@test slope -2.0 rtol = 0.05
81+
@test intercept 10.0 rtol = 0.05
7782
end
7883
end

0 commit comments

Comments
 (0)