Skip to content

Commit ff09591

Browse files
committed
Restructure Gibbs inference tests and reduce iteration counts
1 parent fa81f83 commit ff09591

File tree

1 file changed

+66
-55
lines changed

1 file changed

+66
-55
lines changed

test/mcmc/gibbs.jl

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -337,61 +337,72 @@ end
337337
@test sample(gdemo_default, g, N) isa MCMCChains.Chains
338338
end
339339

340-
@testset "gibbs inference" begin
341-
Random.seed!(100)
342-
alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend))
343-
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
344-
check_numerical(chain, [:m], [7 / 6]; atol=0.15)
345-
# Be more relaxed with the tolerance of the variance.
346-
check_numerical(chain, [:s], [49 / 24]; atol=0.35)
347-
348-
Random.seed!(100)
349-
350-
alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend))
351-
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
352-
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
353-
354-
alg = Gibbs(; s=CSMC(15), m=ESS())
355-
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
356-
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
357-
358-
alg = CSMC(15)
359-
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
360-
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
361-
362-
Random.seed!(200)
363-
gibbs = Gibbs(
364-
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
365-
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
366-
)
367-
chain = sample(MoGtest_default, gibbs, 10_000)
368-
check_MoGtest_default(chain; atol=0.15)
369-
370-
Random.seed!(200)
371-
# Test samplers that are run multiple times, or have overlapping targets.
372-
alg = Gibbs(
373-
@varname(s) => MH(),
374-
(@varname(s), @varname(m)) => MH(),
375-
@varname(m) => ESS(),
376-
@varname(s) => RepeatSampler(MH(), 3),
377-
@varname(m) => HMC(0.2, 4; adtype=adbackend),
378-
(@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend),
379-
)
380-
chain = sample(gdemo(1.5, 2.0), alg, 300)
381-
check_gdemo(chain; atol=0.15)
382-
383-
Random.seed!(200)
384-
gibbs = Gibbs(
385-
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
386-
(@varname(z1), @varname(z2)) => PG(15),
387-
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
388-
(@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2),
389-
(@varname(mu1)) => ESS(),
390-
(@varname(mu2)) => ESS(),
391-
(@varname(z1), @varname(z2)) => PG(15),
392-
)
393-
chain = sample(MoGtest_default, gibbs, 300)
394-
check_MoGtest_default(chain; atol=0.15)
340+
# Test various combinations of samplers against models for which we know the analytical
341+
# posterior mean.
342+
@testset "Gibbs inference" begin
343+
@testset "CSMC and HMC on gdemo" begin
344+
alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend))
345+
chain = sample(gdemo(1.5, 2.0), alg, 3_000)
346+
check_numerical(chain, [:m], [7 / 6]; atol=0.15)
347+
# Be more relaxed with the tolerance of the variance.
348+
check_numerical(chain, [:s], [49 / 24]; atol=0.35)
349+
end
350+
351+
@testset "MH and HMCDA on gdemo" begin
352+
alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend))
353+
chain = sample(gdemo(1.5, 2.0), alg, 3_000)
354+
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
355+
end
356+
357+
@testset "CSMC and ESS on gdemo" begin
358+
alg = Gibbs(; s=CSMC(15), m=ESS())
359+
chain = sample(gdemo(1.5, 2.0), alg, 3_000)
360+
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
361+
end
362+
363+
# TODO(mhauru) Why is this in the Gibbs test suite?
364+
@testset "CSMC on gdemo" begin
365+
alg = CSMC(15)
366+
chain = sample(gdemo(1.5, 2.0), alg, 4_000)
367+
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
368+
end
369+
370+
@testset "PG and HMC on MoGtest_default" begin
371+
gibbs = Gibbs(
372+
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
373+
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
374+
)
375+
chain = sample(MoGtest_default, gibbs, 2_000)
376+
check_MoGtest_default(chain; atol=0.15)
377+
end
378+
379+
@testset "Multiple overlapping samplers on gdemo" begin
380+
# Test samplers that are run multiple times, or have overlapping targets.
381+
alg = Gibbs(
382+
@varname(s) => MH(),
383+
(@varname(s), @varname(m)) => MH(),
384+
@varname(m) => ESS(),
385+
@varname(s) => RepeatSampler(MH(), 3),
386+
@varname(m) => HMC(0.2, 4; adtype=adbackend),
387+
(@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend),
388+
)
389+
chain = sample(gdemo(1.5, 2.0), alg, 500)
390+
check_gdemo(chain; atol=0.15)
391+
end
392+
393+
@testset "Multiple overlapping samplers on MoGtest_default" begin
394+
gibbs = Gibbs(
395+
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
396+
(@varname(z1), @varname(z2)) => PG(15),
397+
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
398+
(@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2),
399+
(@varname(mu1)) => ESS(),
400+
(@varname(mu2)) => ESS(),
401+
(@varname(z1), @varname(z2)) => PG(15),
402+
)
403+
chain = sample(MoGtest_default, gibbs, 500)
404+
check_MoGtest_default(chain; atol=0.15)
405+
end
395406
end
396407

397408
@testset "transitions" begin

0 commit comments

Comments
 (0)