# Test dependencies inferred from usage below (an assumption: the original suite likely loads these in a shared preamble)
using Test, Statistics, Distributions, StableRNGs
using AbstractMCMC, MCMCChains, AdvancedHMC, ADTypes, LogDensityProblems, LogDensityProblemsAD
using JuliaBUGS: @bugs, compile, settrans, getparams, initialize!
using JuliaBUGS.Model: set_evaluation_mode, UseAutoMarginalization, parameters, evaluate_with_marginalization_values!!

@testset "Auto-Marginalization Sampling (NUTS)" begin
    # 2-component GMM with fixed weights; the discrete z[i] are marginalized out.
    mixture_def = @bugs begin
        w[1] = 0.3
        w[2] = 0.7

        # Moderately informative priors to aid identifiability
        mu[1] ~ Normal(-2, 1)
        mu[2] ~ Normal(2, 1)
        sigma[1] ~ Exponential(1)
        sigma[2] ~ Exponential(1)

        for i in 1:N
            z[i] ~ Categorical(w[1:2])
            y[i] ~ Normal(mu[z[i]], sigma[z[i]])
        end
    end
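
    # With z[i] summed out, an unobserved assignment contributes the mixture
    # likelihood
    #   p(y[i] | mu, sigma) = w[1] * Normal(y[i]; mu[1], sigma[1])
    #                       + w[2] * Normal(y[i]; mu[2], sigma[2]),
    # while an observed z[i] contributes Normal(y[i]; mu[z[i]], sigma[z[i]])
    # directly. The marginalized log density is smooth in (mu, sigma), which is
    # what lets gradient-based NUTS handle a model with discrete indicators.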

    # Generate data from the ground-truth parameters
    N = 120
    true_w = [0.3, 0.7]
    true_mu = [-2.0, 2.0]
    true_sigma = [1.0, 1.0]
    rng = StableRNG(1234)
    # Partially observed assignments break label switching and speed convergence
    z_full = Vector{Int}(undef, N)
    z_obs = Vector{Union{Int,Missing}}(undef, N)
    # First 30 assignments pinned to component 1, last 30 to component 2
    for i in 1:30
        z_full[i] = 1
        z_obs[i] = 1
    end
    for i in (N - 29):N
        z_full[i] = 2
        z_obs[i] = 2
    end
    # Middle assignments drawn randomly and left unobserved
    for i in 31:(N - 30)
        z_full[i] = rand(rng, Categorical(true_w))
        z_obs[i] = missing
    end
    # Generate y given the assignments
    y = Vector{Float64}(undef, N)
    for i in 1:N
        y[i] = rand(rng, Normal(true_mu[z_full[i]], true_sigma[z_full[i]]))
    end

    data = (N=N, y=y, z=z_obs)
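    # Note (added commentary): entries of z supplied as `missing` are treated as
    # latent variables, which UseAutoMarginalization then sums out, while the
    # integer entries are conditioned on as observations.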

    # Compile, move to transformed (unconstrained) space, and switch evaluation
    # to auto-marginalization; then wrap with forward-mode AD for NUTS.
    model = compile(mixture_def, data)
    model = settrans(model, true)
    model = set_evaluation_mode(model, UseAutoMarginalization())
    # Initialize near the ground truth for faster convergence
    initialize!(model, (; mu=[-2.0, 2.0], sigma=[1.0, 1.0]))
    ad_model = ADgradient(AutoForwardDiff(), model)
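
    # Smoke check (an illustrative addition, not in the original test): the
    # marginalized log density and its gradient should be finite at the
    # initialization point before we hand the model to NUTS.
    let θ = getparams(model)
        lp, g = LogDensityProblems.logdensity_and_gradient(ad_model, θ)
        @test isfinite(lp)
        @test all(isfinite, g)
    end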

    # Starting point: the current transformed parameters
    θ0 = getparams(model)

    # Short NUTS run to verify we recover the component means reasonably well;
    # 2000 post-adaptation samples keep the Monte Carlo error small.
    n_samples, n_adapts = 2000, 1000
    # Draws are collected as an MCMCChains.Chains object for easy summaries
    chain = AbstractMCMC.sample(
        rng,
        ad_model,
        NUTS(0.65),
        n_samples;
        progress=false,
        chain_type=MCMCChains.Chains,
        n_adapts=n_adapts,
        init_params=θ0,
        discard_initial=n_adapts,
    )
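
    # Optional diagnostics (added commentary): `summarystats(chain)` from
    # MCMCChains reports ESS and R-hat per parameter; R-hat values near 1
    # indicate good mixing on this unimodal marginalized target.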

    # Posterior means of the component locations, straight from the chain
    mu1_hat = mean(chain[Symbol("mu[1]")])
    mu2_hat = mean(chain[Symbol("mu[2]")])

    # The partially observed assignments anchor the labels and the weights are
    # unequal (0.3 vs 0.7), so label switching is unlikely; compare directly to
    # the ground truth with a generous absolute tolerance.
    @test isapprox(mu1_hat, true_mu[1]; atol=0.20)
    @test isapprox(mu2_hat, true_mu[2]; atol=0.20)
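    # Why atol = 0.20 is comfortable (added reasoning): component 1 gets about
    # 30 + 0.3 * 60 ≈ 48 observations and component 2 about 30 + 0.7 * 60 ≈ 72,
    # so with sigma ≈ 1 the posterior standard deviations of mu[1] and mu[2]
    # are roughly 1/sqrt(48) ≈ 0.14 and 1/sqrt(72) ≈ 0.12; the tolerance
    # leaves more than one posterior standard deviation of slack.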
end