
Commit 6b1206e

test on GMM
1 parent 3c0504e commit 6b1206e

1 file changed: +364 -0 lines changed

JuliaBUGS/test/model/auto_marginalization.jl

Lines changed: 364 additions & 0 deletions
@@ -202,6 +202,370 @@ using JuliaBUGS.Model: set_evaluation_mode, UseAutoMarginalization, UseGraph
    @test discrete_count == 3 # z[1], z[2], z[3]
end

@testset "Gaussian Mixture Models" begin
    # Helper function for ground truth mixture likelihood
    function mixture_loglikelihood(y, weights, mus, sigmas)
        n = length(y)
        k = length(weights)
        logp_total = 0.0

        for i in 1:n
            # Log-sum-exp over components for each observation
            log_probs = zeros(k)
            for j in 1:k
                log_probs[j] = log(weights[j]) + logpdf(Normal(mus[j], sigmas[j]), y[i])
            end
            logp_total += LogExpFunctions.logsumexp(log_probs)
        end

        return logp_total
    end

@testset "Two-component mixture with fixed weights" begin
225+
# Simple mixture with fixed mixture weights
226+
mixture_fixed_def = @bugs begin
227+
# Fixed mixture weights
228+
w[1] = 0.3
229+
w[2] = 0.7
230+
231+
# Component parameters
232+
mu[1] ~ Normal(-2, 5)
233+
mu[2] ~ Normal(2, 5)
234+
sigma[1] ~ Exponential(1)
235+
sigma[2] ~ Exponential(1)
236+
237+
# Component assignments (discrete, to be marginalized)
238+
for i in 1:N
239+
z[i] ~ Categorical(w[1:2])
240+
y[i] ~ Normal(mu[z[i]], sigma[z[i]])
241+
end
242+
end
243+
244+
N = 4
245+
y_obs = [-1.5, 2.3, -2.1, 1.8]
246+
data = (N=N, y=y_obs)
247+
248+
model = compile(mixture_fixed_def, data)
249+
model = settrans(model, true)
250+
model = set_evaluation_mode(model, UseAutoMarginalization())
251+
252+
# Should have 4 continuous parameters: sigma[1], sigma[2], mu[2], mu[1]
253+
@test LogDensityProblems.dimension(model) == 4
254+
255+
# Test with specific parameters
256+
# Order: log(sigma[1]), log(sigma[2]), mu[2], mu[1]
257+
test_params = [0.0, 0.0, 2.0, -2.0] # sigmas=1, mu[2]=2, mu[1]=-2
258+
259+
logp_marginalized = LogDensityProblems.logdensity(model, test_params)
260+
261+
# Compute expected value
262+
weights = [0.3, 0.7]
263+
mus = [-2.0, 2.0]
264+
sigmas = [1.0, 1.0]
265+
266+
logp_likelihood = mixture_loglikelihood(y_obs, weights, mus, sigmas)
267+
prior_logp =
268+
logpdf(Normal(-2, 5), -2.0) +
269+
logpdf(Normal(2, 5), 2.0) +
270+
logpdf(Exponential(1), 1.0) +
271+
logpdf(Exponential(1), 1.0)
272+
expected = logp_likelihood + prior_logp
273+
274+
@test isapprox(logp_marginalized, expected; atol=1e-10)
275+
end
276+
277+
@testset "Three-component mixture with fixed weights" begin
278+
# Extend to 3 components with exact verification
279+
mixture_3comp_def = @bugs begin
280+
# Fixed mixture weights
281+
w[1] = 0.2
282+
w[2] = 0.5
283+
w[3] = 0.3
284+
285+
# Component parameters
286+
mu[1] ~ Normal(-3, 5)
287+
mu[2] ~ Normal(0, 5)
288+
mu[3] ~ Normal(3, 5)
289+
for k in 1:3
290+
sigma[k] ~ Exponential(1)
291+
end
292+
293+
# Component assignments
294+
for i in 1:N
295+
z[i] ~ Categorical(w[1:3])
296+
y[i] ~ Normal(mu[z[i]], sigma[z[i]])
297+
end
298+
end
299+
300+
N = 3
301+
y_obs = [-2.5, 0.5, 3.2]
302+
data = (N=N, y=y_obs)
303+
304+
model = compile(mixture_3comp_def, data)
305+
model = settrans(model, true)
306+
model = set_evaluation_mode(model, UseAutoMarginalization())
307+
308+
# Should have 6 continuous parameters: 3 sigmas + 3 mus
309+
@test LogDensityProblems.dimension(model) == 6
310+
311+
# Test with specific parameters
312+
test_params = [0.0, 0.0, 0.0, 3.0, 0.0, -3.0]
313+
# log(sigmas)=0 -> all sigmas=1, mu[3]=3, mu[2]=0, mu[1]=-3
314+
315+
logp_marginalized = LogDensityProblems.logdensity(model, test_params)
316+
317+
# Compute expected value
318+
weights = [0.2, 0.5, 0.3]
319+
mus = [-3.0, 0.0, 3.0]
320+
sigmas = [1.0, 1.0, 1.0]
321+
322+
logp_likelihood = mixture_loglikelihood(y_obs, weights, mus, sigmas)
323+
prior_logp = sum([
324+
logpdf(Normal(-3, 5), -3.0),
325+
logpdf(Normal(0, 5), 0.0),
326+
logpdf(Normal(3, 5), 3.0),
327+
logpdf(Exponential(1), 1.0),
328+
logpdf(Exponential(1), 1.0),
329+
logpdf(Exponential(1), 1.0),
330+
])
331+
expected = logp_likelihood + prior_logp
332+
333+
@test isapprox(logp_marginalized, expected; atol=1e-10)
334+
end
335+
336+
@testset "Label invariance" begin
337+
# Verify that permuting component labels doesn't change log-density
338+
# when weights are equal
339+
mixture_sym_def = @bugs begin
340+
w[1] = 0.5
341+
w[2] = 0.5
342+
343+
mu[1] ~ Normal(0, 10)
344+
mu[2] ~ Normal(0, 10)
345+
sigma[1] ~ Exponential(1)
346+
sigma[2] ~ Exponential(1)
347+
348+
for i in 1:N
349+
z[i] ~ Categorical(w[1:2])
350+
y[i] ~ Normal(mu[z[i]], sigma[z[i]])
351+
end
352+
end
353+
354+
N = 4
355+
y_obs = [1.0, 2.0, -1.0, 3.0]
356+
data = (N=N, y=y_obs)
357+
358+
model = compile(mixture_sym_def, data)
359+
model = settrans(model, true)
360+
model = set_evaluation_mode(model, UseAutoMarginalization())
361+
362+
# Test with original ordering
363+
# Order: log(sigma[1]), log(sigma[2]), mu[2], mu[1]
364+
params1 = [-0.5, 0.0, 3.0, 1.0] # sigma[1]=exp(-0.5), sigma[2]=1, mu[2]=3, mu[1]=1
365+
logp1 = LogDensityProblems.logdensity(model, params1)
366+
367+
# Test with swapped components (swap mu and sigma values)
368+
params2 = [0.0, -0.5, 1.0, 3.0] # sigma[1]=1, sigma[2]=exp(-0.5), mu[2]=1, mu[1]=3
369+
logp2 = LogDensityProblems.logdensity(model, params2)
370+
371+
# The log probabilities should be equal due to symmetry
372+
# (swapping components 1 and 2 completely with equal weights)
373+
@test isapprox(logp1, logp2; atol=1e-10)
374+
end
375+
376+
@testset "Partial observation of z" begin
377+
# Some z[i] are observed, others are marginalized
378+
mixture_partial_def = @bugs begin
379+
w[1] = 0.3
380+
w[2] = 0.7
381+
382+
mu[1] ~ Normal(-2, 5)
383+
mu[2] ~ Normal(2, 5)
384+
sigma ~ Exponential(1) # Shared sigma
385+
386+
for i in 1:N
387+
z[i] ~ Categorical(w[1:2])
388+
y[i] ~ Normal(mu[z[i]], sigma)
389+
end
390+
end
391+
392+
N = 4
393+
# Observe z[1] and z[3], marginalize z[2] and z[4]
394+
data = (N=N, y=[1.0, 2.0, -1.0, 3.0], z=[2, missing, 1, missing])
395+
396+
model = compile(mixture_partial_def, data)
397+
model = settrans(model, true)
398+
model = set_evaluation_mode(model, UseAutoMarginalization())
399+
400+
# Should have 3 continuous parameters: sigma, mu[2], mu[1]
401+
# z[2] and z[4] are marginalized out
402+
@test LogDensityProblems.dimension(model) == 3
403+
404+
# Test evaluation
405+
test_params = [0.0, 2.0, -2.0] # log(sigma)=0->sigma=1, mu[2]=2, mu[1]=-2
406+
logp = LogDensityProblems.logdensity(model, test_params)
407+
408+
# Verify it's finite and reasonable
409+
@test isfinite(logp)
410+
@test logp < 0
411+
412+
# Manually compute expected for observed components
413+
# z[1]=2 -> y[1]=1.0 comes from mu[2]=2
414+
# z[3]=1 -> y[3]=-1.0 comes from mu[1]=-2
415+
# z[2] and z[4] are marginalized
416+
sigma_val = 1.0
417+
mu_vals = [-2.0, 2.0]
418+
weights = [0.3, 0.7]
419+
420+
# Observed parts
421+
logp_obs = (
422+
log(weights[2]) +
423+
logpdf(Normal(mu_vals[2], sigma_val), 1.0) + # z[1]=2, y[1]=1.0
424+
log(weights[1]) +
425+
logpdf(Normal(mu_vals[1], sigma_val), -1.0) # z[3]=1, y[3]=-1.0
426+
)
427+
428+
# Marginalized parts for y[2]=2.0 and y[4]=3.0
429+
logp_marg2 = LogExpFunctions.logsumexp([
430+
log(weights[1]) + logpdf(Normal(mu_vals[1], sigma_val), 2.0),
431+
log(weights[2]) + logpdf(Normal(mu_vals[2], sigma_val), 2.0),
432+
])
433+
logp_marg4 = LogExpFunctions.logsumexp([
434+
log(weights[1]) + logpdf(Normal(mu_vals[1], sigma_val), 3.0),
435+
log(weights[2]) + logpdf(Normal(mu_vals[2], sigma_val), 3.0),
436+
])
437+
438+
logp_likelihood = logp_obs + logp_marg2 + logp_marg4
439+
prior_logp = (
440+
logpdf(Normal(-2, 5), -2.0) +
441+
logpdf(Normal(2, 5), 2.0) +
442+
logpdf(Exponential(1), 1.0)
443+
)
444+
expected = logp_likelihood + prior_logp
445+
446+
@test isapprox(logp, expected; atol=1e-10)
447+
end
448+
449+
@testset "Mixture with Dirichlet prior on weights" begin
450+
# More realistic mixture with learned weights
451+
mixture_dirichlet_def = @bugs begin
452+
# Mixture weights with Dirichlet prior
453+
alpha[1] = 1.0
454+
alpha[2] = 1.0
455+
alpha[3] = 1.0
456+
w[1:3] ~ ddirich(alpha[1:3])
457+
458+
# Component parameters
459+
for k in 1:3
460+
mu[k] ~ Normal(0, 10)
461+
sigma[k] ~ Exponential(1)
462+
end
463+
464+
# Component assignments
465+
for i in 1:N
466+
z[i] ~ Categorical(w[1:3])
467+
y[i] ~ Normal(mu[z[i]], sigma[z[i]])
468+
end
469+
end
470+
471+
N = 5
472+
y_obs = [-3.0, 0.1, 2.9, -2.8, 3.1]
473+
data = (N=N, y=y_obs)
474+
475+
model = compile(mixture_dirichlet_def, data)
476+
model = settrans(model, true)
477+
model = set_evaluation_mode(model, UseAutoMarginalization())
478+
479+
# Should have 8 continuous parameters:
480+
# 3 sigmas + 3 mus + 2 transformed weight components (3-1 due to simplex constraint)
481+
@test LogDensityProblems.dimension(model) == 8
482+
483+
# Test with specific parameters
484+
# Simplex transform for weights [0.2, 0.3, 0.5]
485+
# Using stick-breaking: w1=0.2, w2=0.3, w3=0.5
486+
# This requires specific transformed values
487+
w_target = [0.2, 0.3, 0.5]
488+
# For Dirichlet, use log-ratio transform
489+
log_ratios = [log(w_target[1] / w_target[3]), log(w_target[2] / w_target[3])]
490+
491+
test_params = [
492+
0.0,
493+
0.0,
494+
0.0, # log(sigmas) = 0 -> all sigmas = 1
495+
3.0,
496+
0.0,
497+
-3.0, # mu[3]=3, mu[2]=0, mu[1]=-3
498+
log_ratios[1],
499+
log_ratios[2], # transformed weights
500+
]
501+
502+
logp_marginalized = LogDensityProblems.logdensity(model, test_params)
503+
504+
# Verify it's finite and reasonable
505+
@test isfinite(logp_marginalized)
506+
@test logp_marginalized < 0 # Should be negative for realistic parameters
507+
end
508+
509+
@testset "Hierarchical mixture model" begin
510+
# Mixture with hierarchical structure on component means
511+
hierarchical_mixture_def = @bugs begin
512+
# Hyperpriors
513+
mu_global ~ Normal(0, 10)
514+
tau_global ~ Exponential(1)
515+
516+
# Mixture weights
517+
w[1] = 0.5
518+
w[2] = 0.5
519+
520+
# Component-specific parameters with hierarchical prior
521+
for k in 1:2
522+
mu[k] ~ Normal(mu_global, tau_global)
523+
sigma[k] ~ Exponential(1)
524+
end
525+
526+
# Data generation
527+
for i in 1:N
528+
z[i] ~ Categorical(w[1:2])
529+
y[i] ~ Normal(mu[z[i]], sigma[z[i]])
530+
end
531+
end
532+
533+
N = 6
534+
y_obs = [1.0, 1.2, 4.8, 5.1, 0.9, 5.0]
535+
data = (N=N, y=y_obs)
536+
537+
model = compile(hierarchical_mixture_def, data)
538+
model = settrans(model, true)
539+
model = set_evaluation_mode(model, UseAutoMarginalization())
540+
541+
# Should have 6 continuous parameters:
542+
# mu_global, tau_global, 2 sigmas, 2 mus
543+
@test LogDensityProblems.dimension(model) == 6
544+
545+
# Test evaluation with multiple parameter sets
546+
# Test 1: Parameters that should give reasonable likelihood
547+
test_params = [3.0, 0.0, 0.0, 0.0, 5.0, 1.0]
548+
# mu_global=3, log(tau_global)=0->tau=1, log(sigmas)=0->sigmas=1, mu[2]=5, mu[1]=1
549+
550+
logp_marginalized = LogDensityProblems.logdensity(model, test_params)
551+
552+
# Verify the result is finite and reasonable
553+
@test isfinite(logp_marginalized)
554+
@test logp_marginalized < 0 # Log probability should be negative
555+
556+
# Test 2: Different parameters - should give different likelihood
557+
test_params2 = [2.5, -0.5, -0.5, 0.2, 4.5, 0.5]
558+
logp_marginalized2 = LogDensityProblems.logdensity(model, test_params2)
559+
560+
@test isfinite(logp_marginalized2)
561+
@test logp_marginalized2 != logp_marginalized # Different params should give different results
562+
563+
# Test 3: Verify multiple evaluations are consistent
564+
logp_repeat = LogDensityProblems.logdensity(model, test_params)
565+
@test logp_repeat == logp_marginalized # Same params should give same result
566+
end
567+
end
568+
205569
@testset "Edge cases" begin
206570
@testset "Model with no discrete finite variables" begin
207571
# Simple continuous model - marginalization should work but do nothing special
