Skip to content

Commit 5f00a14

Browse files
committed
add experiment package
1 parent 2d66d1c commit 5f00a14

File tree

14 files changed

+752
-0
lines changed

14 files changed

+752
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name = "AutoMarginalizationExperiments"
2+
uuid = "7a1de1b0-2fb5-4cf5-9df0-9a8847935917"
3+
version = "0.1.0"
4+
5+
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
7+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
11+
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
12+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
15+
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
16+
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
17+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
18+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
19+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
20+
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
21+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
22+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
23+
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
24+
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
25+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
26+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
27+
28+
[compat]
29+
julia = "1.10, 1.11"
30+
Distributions = "0.25"
31+
ADTypes = "1"
32+
ForwardDiff = "0.10, 0.11"
33+
LogDensityProblems = "2"
34+
LogDensityProblemsAD = "1"
35+
OrdinaryDiffEq = "6"
36+
RDatasets = "0.7"
37+
BenchmarkTools = "1"
38+
Graphs = "1"
39+
MetaGraphsNext = "0.6, 0.7"
40+
StaticArrays = "1"
41+
AdvancedHMC = "0.6, 0.7, 0.8"
42+
AbstractMCMC = "5"
43+
MCMCChains = "6, 7"
44+
AbstractPPL = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
45+
LogExpFunctions = "0.3"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# AutoMarginalizationExperiments
2+
3+
A lightweight experiment harness (as a package) to showcase finite‑support discrete auto‑marginalization in JuliaBUGS combined with HMC.
4+
5+
Important: Auto-marginalization is on the current JuliaBUGS branch, not the latest release. Develop JuliaBUGS into this environment first:
6+
7+
```
8+
julia --project=experiments/AutoMarginalizationExperiments -e 'using Pkg; Pkg.develop(path="JuliaBUGS"); Pkg.instantiate()'
9+
```
10+
11+
Then try a quick run:
12+
13+
```
14+
julia --project=experiments/AutoMarginalizationExperiments -e 'using Pkg; Pkg.instantiate()'
15+
julia --project=experiments/AutoMarginalizationExperiments -e 'using AutoMarginalizationExperiments; AutoMarginalizationExperiments.run_gmm_autmarg_nuts(2000, 3)'
16+
```
17+
18+
Goals (aligned with experiments/plan.md):
19+
- Exactness/gradient checks on small models (GMM/HMM).
20+
- Scaling vs weighted frontier width and order selection.
21+
- ODE + finite discrete noise (PK Theoph) with reuse ablations.
22+
- Single changepoint over a finite grid.
23+
- “Must‑have” demos: GMM and HMM with NUTS on the marginalized target.
24+
25+
Folders:
26+
- `src/` — package modules: metrics, ordering helpers, synthetic GMM/HMM, NUTS harness.
27+
- Future: `scripts/` for CLI drivers and CSV logging; `pk_theoph/`, `changepoint_step/` subfolders.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/env julia
2+
using AutoMarginalizationExperiments
3+
using JuliaBUGS
4+
using LogDensityProblems
5+
using BenchmarkTools
6+
using Printf
7+
8+
function bench_gmm(; N_list=[1_000, 5_000, 10_000], K=3, reps=10, seed=1)
9+
@printf "GMM logdensity benchmark (auto-marg), K=%d\n" K
10+
for N in N_list
11+
data, _ = AutoMarginalizationExperiments.synth_gmm(N; seed=seed, weights=fill(1/K,K), mus=collect(range(-2,2; length=K)), sigmas=fill(1.0,K))
12+
dataK = (data..., K=K)
13+
mdef = AutoMarginalizationExperiments.build_gmm_model(K)
14+
model, θ = AutoMarginalizationExperiments.compile_autmarg(mdef, dataK)
15+
# warmup
16+
LogDensityProblems.logdensity(model, θ)
17+
b = @benchmark LogDensityProblems.logdensity($model, $θ) samples=$reps evals=1
18+
@printf " N=%6d median=%.3f ms mean=%.3f ms allocs=%d bytes=%d\n" N (median(b).time/1e6) (mean(b).time/1e6) median(b).allocs median(b).memory
19+
end
20+
end
21+
22+
function bench_hmm(; T_list=[200, 500, 1000], reps=10, seed=1)
23+
@printf "HMM logdensity benchmark (auto-marg), S=2\n"
24+
for T in T_list
25+
data, _ = AutoMarginalizationExperiments.synth_hmm_binary(T; seed=seed)
26+
mdef = AutoMarginalizationExperiments.build_hmm2_model()
27+
model, θ = AutoMarginalizationExperiments.compile_autmarg(mdef, data)
28+
LogDensityProblems.logdensity(model, θ)
29+
b = @benchmark LogDensityProblems.logdensity($model, $θ) samples=$reps evals=1
30+
@printf " T=%6d median=%.3f ms mean=%.3f ms allocs=%d bytes=%d\n" T (median(b).time/1e6) (mean(b).time/1e6) median(b).allocs median(b).memory
31+
end
32+
end
33+
34+
bench_gmm()
35+
println()
36+
bench_hmm()
37+
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#!/usr/bin/env julia
2+
using AutoMarginalizationExperiments
3+
using JuliaBUGS
4+
using ADTypes
5+
using Distributions
6+
using ForwardDiff
7+
using LogExpFunctions
8+
using LogDensityProblems
9+
using LogDensityProblemsAD
10+
using Printf
11+
12+
function main(; N=12, K=2, seed=1)
13+
data, truth = AutoMarginalizationExperiments.synth_gmm(N; seed=seed, weights=fill(1/K,K), mus=collect(range(-1.0,1.0; length=K)), sigmas=fill(0.8, K))
14+
dataK = (data..., K=K)
15+
model_def = AutoMarginalizationExperiments.build_gmm_model(K)
16+
model, θ = AutoMarginalizationExperiments.compile_autmarg(model_def, dataK)
17+
ad_model = ADgradient(AutoForwardDiff(), model)
18+
19+
# Build mapping from θ to variable values
20+
gd = model.graph_evaluation_data
21+
# Continuous parameters only
22+
cont_vars = JuliaBUGS.Model.VarName[]
23+
for vn in gd.sorted_parameters
24+
idx = findfirst(==(vn), gd.sorted_nodes)
25+
if idx !== nothing && gd.node_types[idx] == :continuous
26+
push!(cont_vars, vn)
27+
end
28+
end
29+
var_lengths = Dict{JuliaBUGS.Model.VarName,Int}()
30+
for vn in cont_vars
31+
var_lengths[vn] = model.transformed_var_lengths[vn]
32+
end
33+
offsets = Dict{JuliaBUGS.Model.VarName,Int}()
34+
start = 1
35+
for vn in cont_vars
36+
offsets[vn] = start
37+
start += var_lengths[vn]
38+
end
39+
40+
function unpack(θvec)
41+
T = eltype(θvec)
42+
mus = fill(zero(T), K)
43+
sigmas = fill(zero(T), K)
44+
for vn in cont_vars
45+
name = string(vn)
46+
s = offsets[vn]
47+
# parse index inside brackets
48+
idx = try parse(Int, name[findfirst('[', name)+1:findfirst(']', name)-1]) catch; 0 end
49+
if startswith(name, "mu[") && idx 1 && idx K
50+
mus[idx] = θvec[s]
51+
elseif startswith(name, "sigma[") && idx 1 && idx K
52+
sigmas[idx] = exp(θvec[s])
53+
end
54+
end
55+
return mus, sigmas
56+
end
57+
58+
function logjoint_closed(θvec)
59+
mus, sigmas = unpack(θvec)
60+
@assert length(mus) == K && length(sigmas) == K
61+
# Priors: mu ~ Normal(0,5), sigma ~ Exponential(1) with log-Jacobian from exp transform
62+
lp = 0.0
63+
for k in 1:K
64+
lp += logpdf(Distributions.Normal(0,5), mus[k])
65+
lp += logpdf(Distributions.Exponential(1.0), sigmas[k]) + log(sigmas[k]) # jacobian of exp
66+
end
67+
# Likelihood: product over i of sum_k w_k N(y_i | mu_k, sigma_k)
68+
w = 1.0 / K
69+
for yi in data.y
70+
terms = similar(mus)
71+
@inbounds for k in 1:K
72+
terms[k] = log(w) + logpdf(Distributions.Normal(mus[k], sigmas[k]), yi)
73+
end
74+
lp += LogExpFunctions.logsumexp(terms)
75+
end
76+
return lp
77+
end
78+
79+
val_ad, grad_ad = LogDensityProblems.logdensity_and_gradient(ad_model, θ)
80+
val_cf = logjoint_closed(θ)
81+
grad_cf = ForwardDiff.gradient(logjoint_closed, θ)
82+
83+
@printf "N=%d, K=%d\n" N K
84+
@printf "value: engine=%.8f, closed=%.8f, absdiff=%.3e\n" val_ad val_cf abs(val_ad - val_cf)
85+
@printf "grad max-abs-diff: %.3e\n" maximum(abs.(grad_ad .- grad_cf))
86+
end
87+
88+
main()
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/usr/bin/env julia
2+
using AutoMarginalizationExperiments
3+
using JuliaBUGS
4+
using LogDensityProblems
5+
using Printf
6+
using Random
7+
8+
function hmm_logdensity_with_order(model, θ, order)
9+
# Prepare caches and offsets as in the benchmark script
10+
gd = model.graph_evaluation_data
11+
minimal_keys = AutoMarginalizationExperiments.prepare_minimal_cache_keys(model, order)
12+
# Build continuous-only param order and offsets
13+
cont_vars = JuliaBUGS.Model.VarName[]
14+
var_lengths = Dict{JuliaBUGS.Model.VarName,Int}()
15+
for vn in gd.sorted_parameters
16+
idx = findfirst(==(vn), gd.sorted_nodes)
17+
if idx !== nothing && gd.node_types[idx] == :continuous
18+
push!(cont_vars, vn)
19+
var_lengths[vn] = model.transformed_var_lengths[vn]
20+
end
21+
end
22+
offsets = Dict{JuliaBUGS.Model.VarName,Int}()
23+
start = 1
24+
for vn in cont_vars
25+
offsets[vn] = start
26+
start += var_lengths[vn]
27+
end
28+
env = JuliaBUGS.Model.smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols)
29+
memo = Dict{Tuple{Int,UInt64},Any}()
30+
return JuliaBUGS.Model._marginalize_recursive(
31+
model, env, order, θ, offsets, var_lengths, memo, minimal_keys,
32+
)
33+
end
34+
35+
function peak_frontier_size(minimal_keys)
36+
isempty(minimal_keys) && return 0
37+
return maximum((length(v) for v in values(minimal_keys)))
38+
end
39+
40+
function main(; T=300, reps=10, seed=1)
41+
rng = MersenneTwister(seed)
42+
data, _ = AutoMarginalizationExperiments.synth_hmm_binary(T; seed=seed)
43+
model_def = AutoMarginalizationExperiments.build_hmm2_model()
44+
model, θ = AutoMarginalizationExperiments.compile_autmarg(model_def, data)
45+
46+
gd = model.graph_evaluation_data
47+
default_order = isempty(gd.marginalization_order) ? collect(1:length(gd.sorted_nodes)) : gd.marginalization_order
48+
interleaved = AutoMarginalizationExperiments.build_interleaved_order(model)
49+
50+
# Warmup
51+
hmm_logdensity_with_order(model, θ, default_order)
52+
hmm_logdensity_with_order(model, θ, interleaved)
53+
54+
# Measure
55+
function timeit(order)
56+
t = @elapsed begin
57+
for _ in 1:reps
58+
hmm_logdensity_with_order(model, θ, order)
59+
end
60+
end
61+
mk = AutoMarginalizationExperiments.prepare_minimal_cache_keys(model, order)
62+
return t, peak_frontier_size(mk)
63+
end
64+
65+
t_def, w_def = timeit(default_order)
66+
t_int, w_int = timeit(interleaved)
67+
68+
l_def = hmm_logdensity_with_order(model, θ, default_order)
69+
l_int = hmm_logdensity_with_order(model, θ, interleaved)
70+
71+
@printf "HMM ordering ablation (T=%d, reps=%d)\n" T reps
72+
@printf " default: time=%.4f s, peak_frontier=%d, logp=%.6f\n" t_def w_def l_def
73+
@printf " interlv: time=%.4f s, peak_frontier=%d, logp=%.6f\n" t_int w_int l_int
74+
@printf " abs diff in logp = %.3e\n" abs(l_def - l_int)
75+
end
76+
77+
main()
78+
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env julia
2+
using Pkg
3+
using Printf
4+
5+
root = normpath(joinpath(@__DIR__, "..", "..", ".."))
6+
jbugs = joinpath(root, "JuliaBUGS")
7+
@printf "Developing JuliaBUGS from %s\n" jbugs
8+
Pkg.develop(path=jbugs)
9+
Pkg.instantiate()
10+
println("OK")
11+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module AutoMarginalizationExperiments
2+
3+
using Random
4+
using LinearAlgebra
5+
using Statistics
6+
using Printf
7+
8+
using Distributions
9+
using ADTypes
10+
using LogDensityProblems
11+
using LogDensityProblemsAD
12+
13+
using JuliaBUGS
14+
using JuliaBUGS: @bugs, compile, settrans
15+
import JuliaBUGS.Model
16+
17+
include("metrics.jl")
18+
include("ordering.jl")
19+
include("synth_gmm.jl")
20+
include("synth_hmm.jl")
21+
include("harness.jl")
22+
23+
export
24+
# Metrics
25+
Metrics,
26+
# GMM
27+
synth_gmm, build_gmm_model, run_gmm_autmarg_nuts,
28+
# HMM
29+
synth_hmm_binary, build_hmm2_model, run_hmm_autmarg_nuts,
30+
# Ordering helpers
31+
build_interleaved_order, prepare_minimal_cache_keys
32+
33+
end # module

0 commit comments

Comments
 (0)