Skip to content

Commit 0445a28

Browse files
committed
add sampling tests
1 parent 6de0807 commit 0445a28

File tree

3 files changed

+103
-4
lines changed

3 files changed

+103
-4
lines changed

JuliaBUGS/test/model/auto_marginalization.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -720,12 +720,23 @@ end
720720
# Auto-marginalized model with small-step NUTS
721721
model_marg = compile(mixture_def, data) |> m -> settrans(m, true) |> m -> set_evaluation_mode(m, UseAutoMarginalization())
722722
@test LogDensityProblems.dimension(model_marg) < LogDensityProblems.dimension(model_graph)
723+
# Run gradient-based sampling (NUTS) on the auto-marginalized AD-wrapped model
723724
ad_model = ADgradient(AutoForwardDiff(), model_marg)
724725
D = LogDensityProblems.dimension(model_marg)
725726
θ0 = zeros(D)
726-
println("[AutoMargTest] Efficiency smoke: skipping AutoMarg+NUTS sampling for now"); flush(stdout)
727-
# Quick sanity: logdensity on AD-wrapped model at θ0 is finite
728-
val = LogDensityProblems.logdensity(ad_model, θ0)
729-
@test isfinite(val)
727+
println("[AutoMargTest] Efficiency smoke: sampling AutoMarg+NUTS..."); flush(stdout)
728+
samps = AbstractMCMC.sample(
729+
Random.default_rng(),
730+
ad_model,
731+
NUTS(0.65),
732+
10;
733+
progress=false,
734+
n_adapts=0,
735+
init_params=θ0,
736+
discard_initial=0,
737+
)
738+
println("[AutoMargTest] Efficiency smoke: AutoMarg+NUTS done"); flush(stdout)
739+
# Ensure sampling executed without errors
740+
@test !isnothing(samps)
730741
end
731742
end
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using JuliaBUGS: @bugs, compile, settrans, getparams, initialize!
2+
using JuliaBUGS.Model: set_evaluation_mode, UseAutoMarginalization, parameters, evaluate_with_marginalization_values!!
3+
4+
@testset "Auto-Marginalization Sampling (NUTS)" begin
    # 2-component GMM with fixed weights. The discrete assignment z is
    # marginalized out automatically, so NUTS only sees continuous parameters.
    mixture_def = @bugs begin
        w[1] = 0.3
        w[2] = 0.7

        # Moderately informative priors to aid identifiability
        mu[1] ~ Normal(-2, 1)
        mu[2] ~ Normal(2, 1)
        sigma[1] ~ Exponential(1)
        sigma[2] ~ Exponential(1)

        for i in 1:N
            z[i] ~ Categorical(w[1:2])
            y[i] ~ Normal(mu[z[i]], sigma[z[i]])
        end
    end

    # Generate data from the ground-truth parameters
    N = 120
    true_w = [0.3, 0.7]
    true_mu = [-2.0, 2.0]
    true_sigma = [1.0, 1.0]
    rng = StableRNG(1234)

    # Partially observed assignments to break label switching and speed
    # convergence: the first 30 observations are pinned to component 1, the
    # last 30 to component 2, and the middle assignments are latent (missing).
    z_full = Vector{Int}(undef, N)
    z_obs = Vector{Union{Int,Missing}}(undef, N)
    for i in 1:30
        z_full[i] = 1
        z_obs[i] = 1
    end
    for i in (N - 29):N
        z_full[i] = 2
        z_obs[i] = 2
    end
    for i in 31:(N - 30)
        z_full[i] = rand(rng, Categorical(true_w))
        z_obs[i] = missing
    end

    # Generate y conditional on the (fully known) assignments
    y = Vector{Float64}(undef, N)
    for i in 1:N
        y[i] = rand(rng, Normal(true_mu[z_full[i]], true_sigma[z_full[i]]))
    end

    data = (N=N, y=y, z=z_obs)

    # Compile auto-marginalized model and wrap with AD for NUTS
    model =
        compile(mixture_def, data) |>
        m -> settrans(m, true) |> m -> set_evaluation_mode(m, UseAutoMarginalization())
    # Initialize near ground truth for faster convergence
    initialize!(model, (; mu=[-2.0, 2.0], sigma=[1.0, 1.0]))
    ad_model = ADgradient(AutoForwardDiff(), model)

    # Initialize at current transformed parameters
    θ0 = getparams(model)

    # Short NUTS run to verify we recover the component means reasonably well;
    # collect as MCMCChains.Chains so posterior summaries are easy to query.
    n_samples, n_adapts = 2000, 1000
    chain = AbstractMCMC.sample(
        rng,
        ad_model,
        NUTS(0.65),
        n_samples;
        progress=false,
        chain_type=MCMCChains.Chains,
        n_adapts=n_adapts,
        init_params=θ0,
        discard_initial=n_adapts,
    )

    # Posterior means via the public Chains indexing API. (The previous
    # `mean(chain)[sym].nt.mean[1]` reached into the private `nt` field of
    # ChainDataFrame, which is fragile across MCMCChains versions.)
    mu1_hat = mean(chain[Symbol("mu[1]")])
    mu2_hat = mean(chain[Symbol("mu[2]")])

    # With unequal weights (0.3 vs 0.7) and pinned assignments, label switching
    # is unlikely; compare directly to ground truth with a generous tolerance.
    @test isapprox(mu1_hat, true_mu[1]; atol=0.20)
    @test isapprox(mu2_hat, true_mu[2]; atol=0.20)
end

JuliaBUGS/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ const TEST_GROUPS = OrderedDict{String,Function}(
7979
include("independent_mh.jl")
8080
include("ext/JuliaBUGSAdvancedHMCExt.jl")
8181
include("ext/JuliaBUGSMCMCChainsExt.jl")
82+
include("model/auto_marginalization_sampling.jl")
8283
end,
8384
"inference_hmc" => () -> include("ext/JuliaBUGSAdvancedHMCExt.jl"),
8485
"inference_chains" => () -> include("ext/JuliaBUGSMCMCChainsExt.jl"),

0 commit comments

Comments (0)