Skip to content

Commit 8f66b5f

Browse files
committed
working metropolis hastings (with hacks)
1 parent 1ad167a commit 8f66b5f

File tree

2 files changed

+119
-4
lines changed

2 files changed

+119
-4
lines changed

src/ProbProg.jl

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@ mutable struct ProbProgTrace
1414
choices::Dict{Symbol,Any}
1515
retval::Any
1616
weight::Any
17+
fn::Union{Nothing,Function}
18+
args::Union{Nothing,Tuple}
1719

18-
function ProbProgTrace()
19-
return new(Dict{Symbol,Any}(), nothing, nothing)
20+
function ProbProgTrace(fn::Function, args::Tuple)
21+
return new(Dict{Symbol,Any}(), nothing, nothing, fn, args)
2022
end
23+
24+
ProbProgTrace() = new(Dict{Symbol,Any}(), nothing, nothing, nothing, ())
2125
end
2226

2327
function addSampleToTraceLowered(
@@ -292,7 +296,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
292296
end
293297

294298
function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs}
295-
trace = ProbProgTrace()
299+
trace = ProbProgTrace(f, (args...,))
296300

297301
weight, res = @jit sync = true optimize = :probprog generate_internal(
298302
f, args...; trace, constraints
@@ -416,7 +420,7 @@ function generate_internal(
416420
end
417421

418422
function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
419-
trace = ProbProgTrace()
423+
trace = ProbProgTrace(f, (args...,))
420424

421425
res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace)
422426

@@ -571,4 +575,38 @@ function Base.show(io::IO, trace::ProbProgTrace)
571575
end
572576
end
573577

578+
struct Selection
579+
symbols::Vector{Symbol}
580+
end
581+
582+
select(symbol::Symbol) = Selection([symbol])
583+
584+
choicemap() = Dict{Symbol,Any}()
585+
get_choices(trace::ProbProgTrace) = trace.choices
586+
587+
function metropolis_hastings(trace::ProbProgTrace, sel::Selection)
588+
if trace.fn === nothing
589+
error("MH requires a trace with fn and args recorded")
590+
end
591+
592+
constraints = Dict{Symbol,Any}()
593+
for (sym, val) in trace.choices
594+
sym in sel.symbols && continue
595+
constraints[sym] = [val]
596+
end
597+
598+
new_trace, _ = generate(trace.fn, trace.args...; constraints)
599+
rng_state = new_trace.retval[1] # TODO: this is a temporary hack
600+
601+
log_alpha = new_trace.weight - trace.weight
602+
603+
if log(rand()) < log_alpha
604+
new_trace.args = (rng_state, new_trace.args[2:end]...)
605+
return (new_trace, true)
606+
else
607+
trace.args = (rng_state, trace.args[2:end]...)
608+
return (trace, false)
609+
end
610+
end
611+
574612
end

test/probprog/linear_regression.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using Reactant, Test, Random
2+
using Reactant: ProbProg
3+
4+
# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/
5+
6+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
7+
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
8+
9+
function my_model(seed, xs)
10+
rng = Random.default_rng()
11+
Random.seed!(rng, seed)
12+
13+
slope = ProbProg.sample(
14+
normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf
15+
)
16+
intercept = ProbProg.sample(
17+
normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf
18+
)
19+
20+
ys = ProbProg.sample(
21+
normal,
22+
rng,
23+
slope .* xs .+ intercept,
24+
1.0,
25+
(length(xs),);
26+
symbol=:ys,
27+
logpdf=normal_logpdf,
28+
)
29+
30+
return rng.seed, ys
31+
end
32+
33+
function my_inference_program(xs, ys, num_iters)
34+
xs_r = Reactant.to_rarray(xs)
35+
36+
constraints = ProbProg.choicemap()
37+
constraints[:ys] = [ys]
38+
39+
seed = Reactant.to_rarray(UInt64[1, 4])
40+
41+
trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints)
42+
trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack
43+
44+
for i in 1:num_iters
45+
trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope))
46+
trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept))
47+
choices = ProbProg.get_choices(trace)
48+
@show i, choices[:slope], choices[:intercept]
49+
end
50+
51+
choices = ProbProg.get_choices(trace)
52+
return (choices[:slope], choices[:intercept])
53+
end
54+
55+
@testset "linear_regression" begin
56+
@testset "simulate" begin
57+
seed = Reactant.to_rarray(UInt64[1, 4])
58+
59+
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
60+
xs_r = Reactant.to_rarray(xs)
61+
62+
trace = ProbProg.simulate(my_model, seed, xs_r)
63+
64+
@test haskey(trace.choices, :slope)
65+
@test haskey(trace.choices, :intercept)
66+
@test haskey(trace.choices, :ys)
67+
end
68+
69+
@testset "inference" begin
70+
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
71+
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
72+
73+
slope, intercept = my_inference_program(xs, ys, 1000)
74+
75+
@show slope, intercept
76+
end
77+
end

0 commit comments

Comments
 (0)