From d8f6dfa10bb6842202f4ba9c4527d07d63231051 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 12:36:36 +0100 Subject: [PATCH 01/18] Remove AD backend loops --- test/mcmc/abstractmcmc.jl | 58 +++++++++++++++++++-------------------- test/mcmc/sghmc.jl | 14 +++++----- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index f909df8ef1..5cb8330259 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -112,36 +112,36 @@ end @testset "External samplers" begin @testset "AdvancedHMC.jl" begin - @testset "adtype=$adtype" for adtype in ADUtils.adbackends - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - # Need some functionality to initialize the sampler. - # TODO: Remove this once the constructors in the respective packages become "lazy". - sampler = initialize_nuts(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; adtype, unconstrained=true) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + adtype = Turing.DEFAULT_ADTYPE + + # Need some functionality to initialize the sampler. + # TODO: Remove this once the constructors in the respective packages become "lazy". + sampler = initialize_nuts(model) + sampler_ext = DynamicPPL.Sampler( + externalsampler(sampler; adtype, unconstrained=true) + ) + # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. + # @testset "initial_params" begin + # test_initial_params(model, sampler_ext; n_adapts=0) + # end + + sample_kwargs = ( + n_adapts=1_000, + discard_initial=1_000, + # FIXME: Remove this once we can run `test_initial_params` above. + initial_params=DynamicPPL.VarInfo(model)[:], + ) + + @testset "inference" begin + DynamicPPL.TestUtils.test_sampler( + [model], + sampler_ext, + 2_000; + rtol=0.2, + sampler_name="AdvancedHMC", + sample_kwargs..., ) - # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. - # @testset "initial_params" begin - # test_initial_params(model, sampler_ext; n_adapts=0) - # end - - sample_kwargs = ( - n_adapts=1_000, - discard_initial=1_000, - # FIXME: Remove this once we can run `test_initial_params` above. - initial_params=DynamicPPL.VarInfo(model)[:], - ) - - @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( - [model], - sampler_ext, - 2_000; - rtol=0.2, - sampler_name="AdvancedHMC", - sample_kwargs..., - ) - end end end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index c878f755de..ad36f074d1 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -12,14 +12,14 @@ import Mooncake using Test: @test, @testset using Turing -@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends +@testset "Testing sghmc.jl" begin @testset "sghmc constructor" begin - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} @@ -27,20 +27,20 @@ using Turing @testset "sghmc inference" begin rng = StableRNG(123) - alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend) + alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=Turing.DEFAULT_ADTYPE) chain = sample(rng, gdemo_default, alg, 10_000) check_gdemo(chain; atol=0.1) end end -@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends +@testset "Testing sgld.jl" begin @testset "sgld constructor" begin - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) + alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) + alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} From df85c6f5122359b2057428eeb9c8e5c0a99fe2c0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 15:09:20 +0100 Subject: [PATCH 02/18] Add AD tests to sghmc.jl --- test/mcmc/sghmc.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index ad36f074d1..99f423fc54 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -3,6 +3,9 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo import ..ADUtils +using DynamicPPL.TestUtils.AD: run_ad +using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL: DynamicPPL using Distributions: sample import ForwardDiff using LinearAlgebra: dot @@ -12,6 +15,21 @@ import Mooncake using Test: @test, @testset using Turing +@testset "AD with SGHMC / SGLD" begin + @testset "adtype=$adtype" for adtype in ADUtils.adbackends + @testset "alg=$alg" for alg in [ + SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), + SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), + ] + @testset "model=$(model.f)" for model in DEMO_MODELS + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + end +end + @testset "Testing sghmc.jl" begin @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) From 7369221309b9849546338f595359ef4224d0d8da Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 16:12:02 +0100 Subject: [PATCH 03/18] Remove AD loops from HMC and SGHMC --- test/Project.toml | 2 +- test/mcmc/hmc.jl | 69 +++++++++++++++++++++++++++++----------------- test/mcmc/sghmc.jl | 7 ++++- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index df0af4c978..cb0874da69 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -52,7 +52,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36" +DynamicPPL = "0.36.6" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 4d884b1637..2dbbaee0e7 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -7,6 +7,8 @@ import ..ADUtils using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample using DynamicPPL: DynamicPPL, Sampler +using DynamicPPL.TestUtils.AD: run_ad +using DynamicPPL.TestUtils: DEMO_MODELS import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -18,9 +20,41 @@ import Mooncake using Test: @test, @test_logs, @testset, @test_throws using Turing -@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting HMC tests with $adbackend" +@testset "AD / hmc.jl" begin + # AD tests need to be run with SamplingContext because samplers can potentially + # use this to define custom behaviour in the tilde-pipeline and thus change the + # code executed during model evaluation. + @testset "adtype=$adtype" for adtype in ADUtils.adbackends + @testset "alg=$alg" for alg in [ + HMC(0.1, 10; adtype=adtype), + HMCDA(0.8, 0.75; adtype=adtype), + NUTS(1000, 0.8; adtype=adtype), + ] + @info "Testing AD for $alg" + + @testset "model=$(model.f)" for model in DEMO_MODELS + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + + @testset "Check ADType" begin + seed = 123 + alg = HMC(0.1, 10; adtype=adtype) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + sample(StableRNG(seed), m, alg, 10) + end + end +end + +@testset "Testing hmc.jl" begin + @info "Starting HMC tests" seed = 123 + adbackend = Turing.DEFAULT_ADTYPE @testset "constrained bounded" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @@ -65,12 +99,6 @@ using Turing check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end - @testset "hmc reverse diff" begin - alg = HMC(0.1, 10; adtype=adbackend) - res = sample(StableRNG(seed), gdemo_default, alg, 4_000) - check_gdemo(res; rtol=0.1) - end - # Test the sampling of a matrix-value distribution. @testset "matrix support" begin dist = Wishart(7, [1 0.5; 0.5 1]) @@ -211,10 +239,12 @@ using Turing end @testset "prior" begin + # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance + # which means that it's _very_ difficult to find a good tolerance in the test below:) + prior_dist = truncated(Normal(3, 1); lower=0) + @model function demo_hmc_prior() - # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance - # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1); lower=0) + s ~ prior_dist return m ~ Normal(0, sqrt(s)) end alg = NUTS(1000, 0.8; adtype=adbackend) @@ -222,9 +252,7 @@ using Turing demo_hmc_prior(), DynamicPPL.PriorContext() ) chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0]) - check_numerical( - chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 - ) + check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2) end @testset "warning for difficult init params" begin @@ -292,8 +320,8 @@ using Turing # Extract the `x` like this because running `generated_quantities` was how # the issue was discovered, hence we also want to make sure that it works. - results = generated_quantities(model, chain) - results_prior = generated_quantities(model, chain_prior) + results = returned(model, chain) + results_prior = returned(model, chain_prior) # Make sure none of the samples in the chains resulted in errors. @test all(!isnothing, results) @@ -315,15 +343,6 @@ using Turing @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 end end - - @testset "Check ADType" begin - alg = HMC(0.1, 10; adtype=adbackend) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(StableRNG(seed), m, alg, 10) - end end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 99f423fc54..71dfef5d39 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -15,12 +15,17 @@ import Mooncake using Test: @test, @testset using Turing -@testset "AD with SGHMC / SGLD" begin +@testset "AD / sghmc.jl" begin + # AD tests need to be run with SamplingContext because samplers can potentially + # use this to define custom behaviour in the tilde-pipeline and thus change the + # code executed during model evaluation. @testset "adtype=$adtype" for adtype in ADUtils.adbackends @testset "alg=$alg" for alg in [ SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), ] + @info "Testing AD for $alg" + @testset "model=$(model.f)" for model in DEMO_MODELS rng = StableRNG(123) ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) From 62d414fed38e9d0a1fea860854ddd171070965c2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 16:35:51 +0100 Subject: [PATCH 04/18] Handle Inference.jl --- test/mcmc/Inference.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 5966589090..70ff9fa9bf 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -17,9 +17,11 @@ import Mooncake using Test: @test, @test_throws, @testset using Turing -@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting Inference.jl tests with $adbackend" +@testset "Testing Inference.jl" begin + @info "Starting Inference.jl tests" + seed = 23 + adbackend = Turing.DEFAULT_ADTYPE @testset "threaded sampling" begin # Test that chains with the same seed will sample identically. @@ -44,7 +46,7 @@ using Turing @test chain1.value == chain2.value end - # Should also be stable with am explicit RNG + # Should also be stable with an explicit RNG seed = 5 rng = Random.MersenneTwister(seed) for sampler in samplers @@ -273,17 +275,12 @@ using Turing @testset "forbid global" begin xs = [1.5 2.0] - # xx = 1 @model function fggibbstest(xs) s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) - # xx ~ Normal(m, sqrt(s)) # this is illegal - for i in 1:length(xs) xs[i] ~ Normal(m, sqrt(s)) - # for xx in xs - # xx ~ Normal(m, sqrt(s)) end return s, m end @@ -353,7 +350,7 @@ using Turing ) end - # TODO(mhauru) What is this testing? Why does it not use the looped-over adbackend? + # TODO(mhauru) What is this testing? Why does it use a different adbackend? @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] From a4387a857c577d6d1370b91bfec3fbb2218f93cc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 18:49:01 +0100 Subject: [PATCH 05/18] Remove AD loop in Gibbs --- test/mcmc/gibbs.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index ec9d42048a..9cecfd3af8 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -384,8 +384,9 @@ end @test wuc.non_warmup_count == (num_samples - 1) * num_reps end -@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting Gibbs tests with $adbackend" +@testset "Testing gibbs.jl" begin + @info "Starting Gibbs tests" + adbackend = Turing.DEFAULT_ADTYPE @testset "Gibbs constructors" begin # Create Gibbs samplers with various configurations and ways of passing the @@ -604,12 +605,7 @@ end return m .~ Normal(1.0, 1.0) end model = dynamic_model_with_dot_tilde() - # TODO(mhauru) This is broken because of - # https://github.com/TuringLang/DynamicPPL.jl/issues/700. - @test_broken ( - sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100); - true - ) + sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100) end @testset "Demo models" begin From fd2359c3250c84531bdd099ffb9c31bd54900a6c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 23 May 2025 19:20:35 +0100 Subject: [PATCH 06/18] Fix broken test --- test/mcmc/gibbs.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 9cecfd3af8..0807d1dcef 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -598,7 +598,7 @@ end @model function dynamic_model_with_dot_tilde( num_zs=10, ::Type{M}=Vector{Float64} ) where {M} - z = M(undef, num_zs) + z = Vector{Int}(undef, num_zs) z .~ Poisson(1.0) num_ms = sum(z) m = M(undef, num_ms) @@ -842,7 +842,7 @@ end check_MoGtest_default_z_vector(chain; atol=0.2) end - @testset "externsalsampler" begin + @testset "externalsampler" begin @model function demo_gibbs_external() m1 ~ Normal() m2 ~ Normal() From f257f15dd37f1e7bb13b0efbbbc591c9e8372e4f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 24 May 2025 16:50:55 +0100 Subject: [PATCH 07/18] Mark testsets as verbose=true --- test/mcmc/Inference.jl | 2 +- test/mcmc/abstractmcmc.jl | 2 +- test/mcmc/gibbs.jl | 8 ++++---- test/mcmc/hmc.jl | 2 +- test/mcmc/sghmc.jl | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 70ff9fa9bf..8901ed3606 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -17,7 +17,7 @@ import Mooncake using Test: @test, @test_throws, @testset using Turing -@testset "Testing Inference.jl" begin +@testset verbose = true "Testing Inference.jl" begin @info "Starting Inference.jl tests" seed = 23 diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 5cb8330259..76634200d8 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -110,7 +110,7 @@ function test_initial_params( end end -@testset "External samplers" begin +@testset verbose = true "External samplers" begin @testset "AdvancedHMC.jl" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS adtype = Turing.DEFAULT_ADTYPE diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 0807d1dcef..b118b0142c 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -46,7 +46,7 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::DynamicPPL.Model) = true -@testset "GibbsContext" begin +@testset verbose = true "GibbsContext" begin @testset "type stability" begin struct Wrapper{T<:Real} a::T @@ -384,7 +384,7 @@ end @test wuc.non_warmup_count == (num_samples - 1) * num_reps end -@testset "Testing gibbs.jl" begin +@testset verbose = true "Testing gibbs.jl" begin @info "Starting Gibbs tests" adbackend = Turing.DEFAULT_ADTYPE @@ -608,8 +608,8 @@ end sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100) end - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "Demo model" begin + @testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) samplers = [ Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 2dbbaee0e7..f7c48987ce 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -51,7 +51,7 @@ using Turing end end -@testset "Testing hmc.jl" begin +@testset verbose = true "Testing hmc.jl" begin @info "Starting HMC tests" seed = 123 adbackend = Turing.DEFAULT_ADTYPE diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 71dfef5d39..4af3703857 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -35,7 +35,7 @@ using Turing end end -@testset "Testing sghmc.jl" begin +@testset verbose = true "Testing sghmc.jl" begin @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) @test alg isa SGHMC From be080261162725eb2fac9d89208401da284a785f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 00:07:45 +0100 Subject: [PATCH 08/18] Separate AD tests into their own file --- test/ad.jl | 73 ++++++++++++++++++++++++++++++++++++++++++++++ test/mcmc/hmc.jl | 33 --------------------- test/mcmc/sghmc.jl | 20 ------------- test/runtests.jl | 4 +++ 4 files changed, 77 insertions(+), 53 deletions(-) create mode 100644 test/ad.jl diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 0000000000..afd734b638 --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,73 @@ +module TuringADTests + +using Turing +using DynamicPPL +using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL.TestUtils.AD: run_ad +using StableRNGs: StableRNG +using Test +using ..Models: gdemo_default +using ..ADUtils: ADTypeCheckContext, adbackends + +@testset verbose = true "AD / SamplingContext" begin + # AD tests for gradient-based samplers need to be run with SamplingContext + # because samplers can potentially use this to define custom behaviour in + # the tilde-pipeline and thus change the code executed during model + # evaluation. + @testset "adtype=$adtype" for adtype in adbackends + @testset "alg=$alg" for alg in [ + HMC(0.1, 10; adtype=adtype), + HMCDA(0.8, 0.75; adtype=adtype), + NUTS(1000, 0.8; adtype=adtype), + SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), + SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), + ] + @info "Testing AD for $alg" + + @testset "model=$(model.f)" for model in DEMO_MODELS + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + + @testset "Check ADType" begin + seed = 123 + alg = HMC(0.1, 10; adtype=adtype) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + sample(StableRNG(seed), m, alg, 10) + end + end +end + +@testset verbose = true "AD / GibbsContext" begin + # Gibbs sampling also needs extra AD testing because the models are + # executed with GibbsContext and a subsetted varinfo. (see e.g. + # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in + # src/mcmc/gibbs.jl -- the code here mimics what happens in those + # functions) + @testset "adtype=$adtype" for adtype in adbackends + @testset "model=$(model.f)" for model in DEMO_MODELS + # All the demo models have variables `s` and `m`, so we'll pretend + # that we're using a Gibbs sampler where both of them are sampled + # with a gradient-based sampler (say HMC(0.1, 10)). + # This means we need to construct one with only `s`, and one model with + # only `m`. + global_vi = DynamicPPL.VarInfo(model) + @testset for varnames in ([@varname(s)], [@varname(m)]) + @info "Testing Gibbs AD with model=$(model.f), varnames=$varnames" + conditioned_model = Turing.Inference.make_conditional( + model, varnames, deepcopy(global_vi) + ) + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + end +end + +end # module diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index f7c48987ce..f3c46e9e75 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -7,8 +7,6 @@ import ..ADUtils using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample using DynamicPPL: DynamicPPL, Sampler -using DynamicPPL.TestUtils.AD: run_ad -using DynamicPPL.TestUtils: DEMO_MODELS import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -20,37 +18,6 @@ import Mooncake using Test: @test, @test_logs, @testset, @test_throws using Turing -@testset "AD / hmc.jl" begin - # AD tests need to be run with SamplingContext because samplers can potentially - # use this to define custom behaviour in the tilde-pipeline and thus change the - # code executed during model evaluation. - @testset "adtype=$adtype" for adtype in ADUtils.adbackends - @testset "alg=$alg" for alg in [ - HMC(0.1, 10; adtype=adtype), - HMCDA(0.8, 0.75; adtype=adtype), - NUTS(1000, 0.8; adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any - end - end - - @testset "Check ADType" begin - seed = 123 - alg = HMC(0.1, 10; adtype=adtype) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(StableRNG(seed), m, alg, 10) - end - end -end - @testset verbose = true "Testing hmc.jl" begin @info "Starting HMC tests" seed = 123 diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 4af3703857..f12943cc24 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -15,26 +15,6 @@ import Mooncake using Test: @test, @testset using Turing -@testset "AD / sghmc.jl" begin - # AD tests need to be run with SamplingContext because samplers can potentially - # use this to define custom behaviour in the tilde-pipeline and thus change the - # code executed during model evaluation. - @testset "adtype=$adtype" for adtype in ADUtils.adbackends - @testset "alg=$alg" for alg in [ - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any - end - end - end -end - @testset verbose = true "Testing sghmc.jl" begin @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) diff --git a/test/runtests.jl b/test/runtests.jl index 47b714188e..4832dc110f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,10 @@ end @timeit_include("Aqua.jl") end + @testset "AD" verbose = true begin + @timeit_include("ad.jl") + end + @testset "essential" verbose = true begin @timeit_include("essential/container.jl") end From 0d2955fbb9cbb495e9f8eb1eb659f8f92b580f28 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 01:05:28 +0100 Subject: [PATCH 09/18] Finish AD tests --- test/Project.toml | 1 + test/ad.jl | 202 +++++++++++++++++++++++++++++++++--- test/runtests.jl | 109 ++++++++++--------- test/test_utils/ad_utils.jl | 193 ---------------------------------- 4 files changed, 244 insertions(+), 261 deletions(-) delete mode 100644 test/test_utils/ad_utils.jl diff --git a/test/Project.toml b/test/Project.toml index 7304752a17..d6ed2c79fd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" diff --git a/test/ad.jl b/test/ad.jl index afd734b638..a390c21ca5 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -7,14 +7,200 @@ using DynamicPPL.TestUtils.AD: run_ad using StableRNGs: StableRNG using Test using ..Models: gdemo_default -using ..ADUtils: ADTypeCheckContext, adbackends + +"""Element types that are always valid for a VarInfo regardless of ADType.""" +const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) + +"""A dictionary mapping ADTypes to the element types they use.""" +const eltypes_by_adtype = Dict( + Turing.AutoForwardDiff => (ForwardDiff.Dual,), + Turing.AutoReverseDiff => ( + ReverseDiff.TrackedArray, + ReverseDiff.TrackedMatrix, + ReverseDiff.TrackedReal, + ReverseDiff.TrackedStyle, + ReverseDiff.TrackedType, + ReverseDiff.TrackedVecOrMat, + ReverseDiff.TrackedVector, + ), + Turing.AutoMooncake => (Mooncake.CoDual,), +) + +""" + AbstractWrongADBackendError + +An abstract error thrown when we seem to be using a different AD backend than expected. +""" +abstract type AbstractWrongADBackendError <: Exception end + +""" + WrongADBackendError + +An error thrown when we seem to be using a different AD backend than expected. +""" +struct WrongADBackendError <: AbstractWrongADBackendError + actual_adtype::Type + expected_adtype::Type +end + +function Base.showerror(io::IO, e::WrongADBackendError) + return print( + io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." + ) +end + +""" + IncompatibleADTypeError + +An error thrown when an element type is encountered that is unexpected for the given ADType. +""" +struct IncompatibleADTypeError <: AbstractWrongADBackendError + valtype::Type + adtype::Type +end + +function Base.showerror(io::IO, e::IncompatibleADTypeError) + return print( + io, + "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", + ) +end + +""" + ADTypeCheckContext{ADType,ChildContext} + +A context for checking that the expected ADType is being used. + +Evaluating a model with this context will check that the types of values in a `VarInfo` are +compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` +is thrown. + +For instance, evaluating a model with +`ADTypeCheckContext(AutoForwardDiff(), child_context)` +would throw an error if within the model a type associated with e.g. ReverseDiff was +encountered. + +""" +struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext + child::ChildContext + + function ADTypeCheckContext(adbackend, child) + adtype = adbackend isa Type ? adbackend : typeof(adbackend) + if !any(adtype <: k for k in keys(eltypes_by_adtype)) + throw(ArgumentError("Unsupported ADType: $adtype")) + end + return new{adtype,typeof(child)}(child) + end +end + +adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType + +DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child +function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) + return ADTypeCheckContext(adtype(c), child) +end + +""" + valid_eltypes(context::ADTypeCheckContext) + +Return the element types that are valid for the ADType of `context` as a tuple. +""" +function valid_eltypes(context::ADTypeCheckContext) + context_at = adtype(context) + for at in keys(eltypes_by_adtype) + if context_at <: at + return (eltypes_by_adtype[at]..., always_valid_eltypes...) + end + end + # This should never be reached due to the check in the inner constructor. + throw(ArgumentError("Unsupported ADType: $(adtype(context))")) +end + +""" + check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo) + +Check that the element types in `vi` are compatible with the ADType of `context`. + +Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. +""" +function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) + valids = valid_eltypes(context) + for val in vi[:] + valtype = typeof(val) + if !any(valtype .<: valids) + throw(IncompatibleADTypeError(valtype, adtype(context))) + end + end + return nothing +end + +# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child +# context, and then call check_adtype on the result before returning the results from the +# child context. + +function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) + value, logp, vi = DynamicPPL.tilde_assume( + DynamicPPL.childcontext(context), right, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi +) + value, logp, vi = DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) + logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) + check_adtype(context, vi) + return logp, vi +end + +function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) + logp, vi = DynamicPPL.tilde_observe( + DynamicPPL.childcontext(context), sampler, right, left, vi + ) + check_adtype(context, vi) + return logp, vi +end + +""" +All the ADTypes on which we want to run the tests. +""" +ADTYPES = [ + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(; compile=false), + Turing.AutoMooncake(; config=nothing), +] + +@testset verbose = true "AD / ADTypeCheckContext" begin + # This testset ensures that samplers don't accidentally override the AD + # backend set in it. + @testset "Check ADType" begin + seed = 123 + alg = HMC(0.1, 10; adtype=adtype) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + sample(StableRNG(seed), m, alg, 10) + end +end @testset verbose = true "AD / SamplingContext" begin # AD tests for gradient-based samplers need to be run with SamplingContext # because samplers can potentially use this to define custom behaviour in # the tilde-pipeline and thus change the code executed during model # evaluation. - @testset "adtype=$adtype" for adtype in adbackends + @testset "adtype=$adtype" for adtype in ADTYPES @testset "alg=$alg" for alg in [ HMC(0.1, 10; adtype=adtype), HMCDA(0.8, 0.75; adtype=adtype), @@ -30,16 +216,6 @@ using ..ADUtils: ADTypeCheckContext, adbackends @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any end end - - @testset "Check ADType" begin - seed = 123 - alg = HMC(0.1, 10; adtype=adtype) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(StableRNG(seed), m, alg, 10) - end end end @@ -49,7 +225,7 @@ end # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in # src/mcmc/gibbs.jl -- the code here mimics what happens in those # functions) - @testset "adtype=$adtype" for adtype in adbackends + @testset "adtype=$adtype" for adtype in ADTYPES @testset "model=$(model.f)" for model in DEMO_MODELS # All the demo models have variables `s` and `m`, so we'll pretend # that we're using a Gibbs sampler where both of them are sampled diff --git a/test/runtests.jl b/test/runtests.jl index 4832dc110f..e644b6dd4e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,9 +9,8 @@ import Turing # Fix the global Random.seed for reproducibility. seed!(23) -include(pkgdir(Turing) * "/test/test_utils/models.jl") -include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl") -include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl") +include("test_utils/models.jl") +include("test_utils/numerical_tests.jl") Turing.setprogress!(false) included_paths, excluded_paths = parse_args(ARGS) @@ -30,63 +29,63 @@ macro timeit_include(path::AbstractString) end @testset "Turing" verbose = true begin - @testset "Test utils" begin - @timeit_include("test_utils/test_utils.jl") - end - - @testset "Aqua" begin - @timeit_include("Aqua.jl") - end + # @testset "Test utils" begin + # @timeit_include("test_utils/test_utils.jl") + # end + # + # @testset "Aqua" begin + # @timeit_include("Aqua.jl") + # end @testset "AD" verbose = true begin @timeit_include("ad.jl") end - @testset "essential" verbose = true begin - @timeit_include("essential/container.jl") - end - - @testset "samplers (without AD)" verbose = true begin - @timeit_include("mcmc/particle_mcmc.jl") - @timeit_include("mcmc/emcee.jl") - @timeit_include("mcmc/ess.jl") - @timeit_include("mcmc/is.jl") - end - - @timeit TIMEROUTPUT "inference" begin - @testset "inference with samplers" verbose = true begin - @timeit_include("mcmc/gibbs.jl") - @timeit_include("mcmc/hmc.jl") - @timeit_include("mcmc/Inference.jl") - @timeit_include("mcmc/sghmc.jl") - @timeit_include("mcmc/abstractmcmc.jl") - @timeit_include("mcmc/mh.jl") - @timeit_include("ext/dynamichmc.jl") - @timeit_include("mcmc/repeat_sampler.jl") - end - - @testset "variational algorithms" begin - @timeit_include("variational/advi.jl") - end - - @testset "mode estimation" verbose = true begin - @timeit_include("optimisation/Optimisation.jl") - @timeit_include("ext/OptimInterface.jl") - end - end - - @testset "variational optimisers" begin - @timeit_include("variational/optimisers.jl") - end - - @testset "stdlib" verbose = true begin - @timeit_include("stdlib/distributions.jl") - @timeit_include("stdlib/RandomMeasures.jl") - end - - @testset "utilities" begin - @timeit_include("mcmc/utilities.jl") - end + # @testset "essential" verbose = true begin + # @timeit_include("essential/container.jl") + # end + # + # @testset "samplers (without AD)" verbose = true begin + # @timeit_include("mcmc/particle_mcmc.jl") + # @timeit_include("mcmc/emcee.jl") + # @timeit_include("mcmc/ess.jl") + # @timeit_include("mcmc/is.jl") + # end + # + # @timeit TIMEROUTPUT "inference" begin + # @testset "inference with samplers" verbose = true begin + # @timeit_include("mcmc/gibbs.jl") + # @timeit_include("mcmc/hmc.jl") + # @timeit_include("mcmc/Inference.jl") + # @timeit_include("mcmc/sghmc.jl") + # @timeit_include("mcmc/abstractmcmc.jl") + # @timeit_include("mcmc/mh.jl") + # @timeit_include("ext/dynamichmc.jl") + # @timeit_include("mcmc/repeat_sampler.jl") + # end + # + # @testset "variational algorithms" begin + # @timeit_include("variational/advi.jl") + # end + # + # @testset "mode estimation" verbose = true begin + # @timeit_include("optimisation/Optimisation.jl") + # @timeit_include("ext/OptimInterface.jl") + # end + # end + # + # @testset "variational optimisers" begin + # @timeit_include("variational/optimisers.jl") + # end + # + # @testset "stdlib" verbose = true begin + # @timeit_include("stdlib/distributions.jl") + # @timeit_include("stdlib/RandomMeasures.jl") + # end + # + # @testset "utilities" begin + # @timeit_include("mcmc/utilities.jl") + # end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl deleted file mode 100644 index 309276407a..0000000000 --- a/test/test_utils/ad_utils.jl +++ /dev/null @@ -1,193 +0,0 @@ -module ADUtils - -using ForwardDiff: ForwardDiff -using Pkg: Pkg -using Random: Random -using ReverseDiff: ReverseDiff -using Mooncake: Mooncake -using Test: Test -using Turing: Turing -using Turing: DynamicPPL - -export ADTypeCheckContext, adbackends - -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# Stuff for checking that the right AD backend is being used. - -"""Element types that are always valid for a VarInfo regardless of ADType.""" -const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) - -"""A dictionary mapping ADTypes to the element types they use.""" -const eltypes_by_adtype = Dict( - Turing.AutoForwardDiff => (ForwardDiff.Dual,), - Turing.AutoReverseDiff => ( - ReverseDiff.TrackedArray, - ReverseDiff.TrackedMatrix, - ReverseDiff.TrackedReal, - ReverseDiff.TrackedStyle, - ReverseDiff.TrackedType, - ReverseDiff.TrackedVecOrMat, - ReverseDiff.TrackedVector, - ), - Turing.AutoMooncake => (Mooncake.CoDual,), -) - -""" - AbstractWrongADBackendError - -An abstract error thrown when we seem to be using a different AD backend than expected. -""" -abstract type AbstractWrongADBackendError <: Exception end - -""" - WrongADBackendError - -An error thrown when we seem to be using a different AD backend than expected. -""" -struct WrongADBackendError <: AbstractWrongADBackendError - actual_adtype::Type - expected_adtype::Type -end - -function Base.showerror(io::IO, e::WrongADBackendError) - return print( - io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." - ) -end - -""" - IncompatibleADTypeError - -An error thrown when an element type is encountered that is unexpected for the given ADType. -""" -struct IncompatibleADTypeError <: AbstractWrongADBackendError - valtype::Type - adtype::Type -end - -function Base.showerror(io::IO, e::IncompatibleADTypeError) - return print( - io, - "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", - ) -end - -""" - ADTypeCheckContext{ADType,ChildContext} - -A context for checking that the expected ADType is being used. - -Evaluating a model with this context will check that the types of values in a `VarInfo` are -compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` -is thrown. - -For instance, evaluating a model with -`ADTypeCheckContext(AutoForwardDiff(), child_context)` -would throw an error if within the model a type associated with e.g. ReverseDiff was -encountered. - -""" -struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: - DynamicPPL.AbstractContext - child::ChildContext - - function ADTypeCheckContext(adbackend, child) - adtype = adbackend isa Type ? adbackend : typeof(adbackend) - if !any(adtype <: k for k in keys(eltypes_by_adtype)) - throw(ArgumentError("Unsupported ADType: $adtype")) - end - return new{adtype,typeof(child)}(child) - end -end - -adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType - -DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child -function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) - return ADTypeCheckContext(adtype(c), child) -end - -""" - valid_eltypes(context::ADTypeCheckContext) - -Return the element types that are valid for the ADType of `context` as a tuple. -""" -function valid_eltypes(context::ADTypeCheckContext) - context_at = adtype(context) - for at in keys(eltypes_by_adtype) - if context_at <: at - return (eltypes_by_adtype[at]..., always_valid_eltypes...) - end - end - # This should never be reached due to the check in the inner constructor. - throw(ArgumentError("Unsupported ADType: $(adtype(context))")) -end - -""" - check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo) - -Check that the element types in `vi` are compatible with the ADType of `context`. - -Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. -""" -function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) - valids = valid_eltypes(context) - for val in vi[:] - valtype = typeof(val) - if !any(valtype .<: valids) - throw(IncompatibleADTypeError(valtype, adtype(context))) - end - end - return nothing -end - -# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child -# context, and then call check_adtype on the result before returning the results from the -# child context. - -function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) - check_adtype(context, vi) - return value, logp, vi -end - -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi -) - value, logp, vi = DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) - check_adtype(context, vi) - return value, logp, vi -end - -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) - check_adtype(context, vi) - return logp, vi -end - -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( - DynamicPPL.childcontext(context), sampler, right, left, vi - ) - check_adtype(context, vi) - return logp, vi -end - -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# List of AD backends to test. - -""" -All the ADTypes on which we want to run the tests. -""" -adbackends = [ - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(; compile=false), - Turing.AutoMooncake(; config=nothing), -] - -end From 5d654c79b788082c529ee18848af27d2b0422dcd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 01:20:38 +0100 Subject: [PATCH 10/18] No I did not mean to comment out all other tests --- test/runtests.jl | 104 +++++++++++++++++++++++------------------------ 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e644b6dd4e..493a4eb2f6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,63 +29,63 @@ macro timeit_include(path::AbstractString) end @testset "Turing" verbose = true begin - # @testset "Test utils" begin - # @timeit_include("test_utils/test_utils.jl") - # end - # - # @testset "Aqua" begin - # @timeit_include("Aqua.jl") - # end + @testset "Test utils" begin + @timeit_include("test_utils/test_utils.jl") + end + + @testset "Aqua" begin + @timeit_include("Aqua.jl") + end @testset "AD" verbose = true begin @timeit_include("ad.jl") end - # @testset "essential" verbose = true begin - # @timeit_include("essential/container.jl") - # end - # - # @testset "samplers (without AD)" verbose = true begin - # @timeit_include("mcmc/particle_mcmc.jl") - # @timeit_include("mcmc/emcee.jl") - # @timeit_include("mcmc/ess.jl") - # @timeit_include("mcmc/is.jl") - # end - # - # @timeit TIMEROUTPUT "inference" begin - # @testset "inference with samplers" verbose = true begin - # @timeit_include("mcmc/gibbs.jl") - # @timeit_include("mcmc/hmc.jl") - # @timeit_include("mcmc/Inference.jl") - # @timeit_include("mcmc/sghmc.jl") - # @timeit_include("mcmc/abstractmcmc.jl") - # @timeit_include("mcmc/mh.jl") - # @timeit_include("ext/dynamichmc.jl") - # @timeit_include("mcmc/repeat_sampler.jl") - # end - # - # @testset "variational algorithms" begin - # @timeit_include("variational/advi.jl") - # end - # - # @testset "mode estimation" verbose = true begin - # @timeit_include("optimisation/Optimisation.jl") - # @timeit_include("ext/OptimInterface.jl") - # end - # end - # - # @testset "variational optimisers" begin - # @timeit_include("variational/optimisers.jl") - # end - # - # @testset "stdlib" verbose = true begin - # @timeit_include("stdlib/distributions.jl") - # @timeit_include("stdlib/RandomMeasures.jl") - # end - # - # @testset "utilities" begin - # @timeit_include("mcmc/utilities.jl") - # end + @testset "essential" verbose = true begin + @timeit_include("essential/container.jl") + end + + @testset "samplers (without AD)" verbose = true begin + @timeit_include("mcmc/particle_mcmc.jl") + @timeit_include("mcmc/emcee.jl") + @timeit_include("mcmc/ess.jl") + @timeit_include("mcmc/is.jl") + end + + @timeit TIMEROUTPUT "inference" begin + @testset "inference with samplers" verbose = true begin + @timeit_include("mcmc/gibbs.jl") + @timeit_include("mcmc/hmc.jl") + @timeit_include("mcmc/Inference.jl") + @timeit_include("mcmc/sghmc.jl") + @timeit_include("mcmc/abstractmcmc.jl") + @timeit_include("mcmc/mh.jl") + @timeit_include("ext/dynamichmc.jl") + @timeit_include("mcmc/repeat_sampler.jl") + end + + @testset "variational algorithms" begin + @timeit_include("variational/advi.jl") + end + + @testset "mode estimation" verbose = true begin + @timeit_include("optimisation/Optimisation.jl") + @timeit_include("ext/OptimInterface.jl") + end + end + + @testset "variational optimisers" begin + @timeit_include("variational/optimisers.jl") + end + + @testset "stdlib" verbose = true begin + @timeit_include("stdlib/distributions.jl") + @timeit_include("stdlib/RandomMeasures.jl") + end + + @testset "utilities" begin + @timeit_include("mcmc/utilities.jl") + end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From f59a89c326a9b05c2ed3ed2fb7a6dbe2ccdc5c58 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 01:25:39 +0100 Subject: [PATCH 11/18] Move ADTypeCheckContexts and optimisation tests to ad.jl as well --- test/ad.jl | 40 ++++++++++++++++++++++++++++--- test/mcmc/Inference.jl | 1 - test/mcmc/abstractmcmc.jl | 1 - test/mcmc/gibbs.jl | 1 - test/mcmc/hmc.jl | 2 -- test/mcmc/sghmc.jl | 1 - test/optimisation/Optimisation.jl | 11 --------- test/test_utils/test_utils.jl | 30 ----------------------- 8 files changed, 37 insertions(+), 50 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index a390c21ca5..a88d23ae4e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -4,9 +4,11 @@ using Turing using DynamicPPL using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad +using Random: Random using StableRNGs: StableRNG using Test using ..Models: gdemo_default +import ForwardDiff, ReverseDiff, Mooncake """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) @@ -181,10 +183,40 @@ ADTYPES = [ Turing.AutoMooncake(; config=nothing), ] +# Check that ADTypeCheckContext itself works as expected. +@testset "ADTypeCheckContext" begin + @model test_model() = x ~ Normal(0, 1) + tm = test_model() + adtypes = ( + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(), + # TODO: Mooncake + # Turing.AutoMooncake(config=nothing), + ) + for actual_adtype in adtypes + sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) + for expected_adtype in adtypes + contextualised_tm = DynamicPPL.contextualize( + tm, ADTypeCheckContext(expected_adtype, tm.context) + ) + @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin + if actual_adtype == expected_adtype + # Check that this does not throw an error. + Turing.sample(contextualised_tm, sampler, 2) + else + @test_throws AbstractWrongADBackendError Turing.sample( + contextualised_tm, sampler, 2 + ) + end + end + end + end +end + @testset verbose = true "AD / ADTypeCheckContext" begin - # This testset ensures that samplers don't accidentally override the AD - # backend set in it. - @testset "Check ADType" begin + # This testset ensures that samplers or optimisers don't accidentally + # override the AD backend set in it. + @testset "adtype=$adtype" for adtype in ADTYPES seed = 123 alg = HMC(0.1, 10; adtype=adtype) m = DynamicPPL.contextualize( @@ -192,6 +224,8 @@ ADTYPES = [ ) # These will error if the adbackend being used is not the one set. sample(StableRNG(seed), m, alg, 10) + maximum_likelihood(m; adtype=adtype) + maximum_a_posteriori(m; adtype=adtype) end end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 737c8ba714..736af9c393 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -2,7 +2,6 @@ module InferenceTests using ..Models: gdemo_d, gdemo_default using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 76634200d8..7651752ba6 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -1,6 +1,5 @@ module AbstractMCMCTests -import ..ADUtils using AbstractMCMC: AbstractMCMC using AdvancedMH: AdvancedMH using Distributions: sample diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index b118b0142c..93630c9fd1 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -7,7 +7,6 @@ using ..NumericalTests: check_gdemo, check_numerical, two_sample_test -import ..ADUtils import Combinatorics using AbstractMCMC: AbstractMCMC using Distributions: InverseGamma, Normal diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index f3c46e9e75..fd588b8f0a 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -1,9 +1,7 @@ module HMCTests using ..Models: gdemo_default -using ..ADUtils: ADTypeCheckContext using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample using DynamicPPL: DynamicPPL, Sampler diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index f12943cc24..cea361c432 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -2,7 +2,6 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo -import ..ADUtils using DynamicPPL.TestUtils.AD: run_ad using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL: DynamicPPL diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 9894d621ce..cf692cba25 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -1,7 +1,6 @@ module OptimisationTests using ..Models: gdemo, gdemo_default -using ..ADUtils: ADUtils using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL @@ -624,16 +623,6 @@ using Turing @assert get(result, :c) == (; :c => Array{Float64}[]) end - @testset "ADType test with $adbackend" for adbackend in ADUtils.adbackends - Random.seed!(222) - m = DynamicPPL.contextualize( - gdemo_default, ADUtils.ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - maximum_likelihood(m; adtype=adbackend) - maximum_a_posteriori(m; adtype=adbackend) - end - @testset "Collinear coeftable" begin xs = [-1.0, 0.0, 1.0] ys = [0.0, 0.0, 0.0] diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index 243a80b881..b140bf7a78 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -8,34 +8,4 @@ using Test: @test, @testset, @test_throws using Turing: Turing using Turing: DynamicPPL -# Check that the ADTypeCheckContext works as expected. -@testset "ADTypeCheckContext" begin - Turing.@model test_model() = x ~ Turing.Normal(0, 1) - tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), - ) - for actual_adtype in adtypes - sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) - for expected_adtype in adtypes - contextualised_tm = DynamicPPL.contextualize( - tm, ADTypeCheckContext(expected_adtype, tm.context) - ) - @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin - if actual_adtype == expected_adtype - # Check that this does not throw an error. - Turing.sample(contextualised_tm, sampler, 2) - else - @test_throws AbstractWrongADBackendError Turing.sample( - contextualised_tm, sampler, 2 - ) - end - end - end - end -end - end From 0c8b28ff1c03736d7b9f07f5c4d97ccddae4477b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 01:45:18 +0100 Subject: [PATCH 12/18] Remove stray test_utils file --- test/runtests.jl | 4 ---- test/test_utils/test_utils.jl | 11 ----------- 2 files changed, 15 deletions(-) delete mode 100644 test/test_utils/test_utils.jl diff --git a/test/runtests.jl b/test/runtests.jl index 493a4eb2f6..69f8045dd8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,10 +29,6 @@ macro timeit_include(path::AbstractString) end @testset "Turing" verbose = true begin - @testset "Test utils" begin - @timeit_include("test_utils/test_utils.jl") - end - @testset "Aqua" begin @timeit_include("Aqua.jl") end diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl deleted file mode 100644 index b140bf7a78..0000000000 --- a/test/test_utils/test_utils.jl +++ /dev/null @@ -1,11 +0,0 @@ -"""Module for testing the test utils themselves.""" -module TestUtilsTests - -using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -using Test: @test, @testset, @test_throws -using Turing: Turing -using Turing: DynamicPPL - -end From 4210f120c90879476f461b64ce2d99bc257471e9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 15:53:05 +0100 Subject: [PATCH 13/18] Remove has_dot_assume check --- test/mcmc/gibbs.jl | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 93630c9fd1..5e4ef2149b 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -33,18 +33,6 @@ function check_transition_varnames(transition::Turing.Inference.Transition, pare end end -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::DynamicPPL.Model) = true - @testset verbose = true "GibbsContext" begin @testset "type stability" begin struct Wrapper{T<:Real} @@ -614,19 +602,10 @@ end Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)), Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()), + Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), + Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), ] - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), - Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), - ], - ) - end - @testset "$sampler" for sampler in samplers # Check that taking steps performs as expected. rng = Random.default_rng() From e86938e00ffca45119d47c4cf9c65388cd8d392c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 15:53:34 +0100 Subject: [PATCH 14/18] Remove TODO comment about Mooncake --- test/ad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index a88d23ae4e..8b75c95004 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -190,8 +190,7 @@ ADTYPES = [ adtypes = ( Turing.AutoForwardDiff(), Turing.AutoReverseDiff(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), + # Don't need to test Mooncake as it doesn't use tracer types ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) From 2b2f352fa920e300d0fc0ca977877821b312a92a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 15:57:30 +0100 Subject: [PATCH 15/18] Test Gibbs sampling --- test/ad.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/ad.jl b/test/ad.jl index 8b75c95004..157a066286 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -279,4 +279,14 @@ end end end +@testset verbose = true "AD / Gibbs sampling" begin + # Make sure that Gibbs sampling doesn't fall over when using AD. + spl = Gibbs(@varname(s) => HMC(0.1, 10), @varname(m) => HMC(0.1, 10)) + @testset "adtype=$adtype" for adtype in ADTYPES + @testset "model=$(model.f)" for model in DEMO_MODELS + @test sample(model, spl, 2) isa Any + end + end +end + end # module From 5e7ed7dcfa7ade904b100c475a792b8d618cbe6c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 28 May 2025 16:55:16 +0100 Subject: [PATCH 16/18] Fix test --- test/ad.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/ad.jl b/test/ad.jl index 157a066286..30cfc13d7a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -281,8 +281,11 @@ end @testset verbose = true "AD / Gibbs sampling" begin # Make sure that Gibbs sampling doesn't fall over when using AD. - spl = Gibbs(@varname(s) => HMC(0.1, 10), @varname(m) => HMC(0.1, 10)) @testset "adtype=$adtype" for adtype in ADTYPES + spl = Gibbs( + @varname(s) => HMC(0.1, 10; adtype=adtype), + @varname(m) => HMC(0.1, 10; adtype=adtype), + ) @testset "model=$(model.f)" for model in DEMO_MODELS @test sample(model, spl, 2) isa Any end From c0be80c463a485f69e19c9a35c75cfb7d80b9c25 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 29 May 2025 13:35:28 +0100 Subject: [PATCH 17/18] Remove unnecessary DEFAULT_ADTYPEs --- test/mcmc/Inference.jl | 42 +++++++++++++++++------------------------- test/mcmc/gibbs.jl | 32 ++++++++++++++------------------ test/mcmc/hmc.jl | 42 ++++++++++++++++++------------------------ test/mcmc/sghmc.jl | 12 ++++++------ 4 files changed, 55 insertions(+), 73 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 736af9c393..78d79a5480 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -20,7 +20,6 @@ using Turing @info "Starting Inference.jl tests" seed = 23 - adbackend = Turing.DEFAULT_ADTYPE @testset "threaded sampling" begin # Test that chains with the same seed will sample identically. @@ -28,12 +27,12 @@ using Turing model = gdemo_default samplers = ( - HMC(0.1, 7; adtype=adbackend), + HMC(0.1, 7), PG(10), IS(), MH(), - Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)), - Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()), + Gibbs(:s => PG(3), :m => HMC(0.4, 8)), + Gibbs(:s => HMC(0.1, 5), :m => ESS()), ) for sampler in samplers Random.seed!(5) @@ -62,27 +61,22 @@ using Turing # Smoke test for default sample call. @testset "gdemo_default" begin chain = sample( - StableRNG(seed), - gdemo_default, - HMC(0.1, 7; adtype=adbackend), - MCMCThreads(), - 1_000, - 4, + StableRNG(seed), gdemo_default, HMC(0.1, 7), MCMCThreads(), 1_000, 4 ) check_gdemo(chain) # run sampler: progress logging should be disabled and # it should return a Chains object - sampler = Sampler(HMC(0.1, 7; adtype=adbackend)) + sampler = Sampler(HMC(0.1, 7)) chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4) @test chains isa MCMCChains.Chains end end @testset "chain save/resume" begin - alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) + alg1 = HMCDA(1000, 0.65, 0.15) alg2 = PG(20) - alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend)) + alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4)) chn1 = sample(StableRNG(seed), gdemo_default, alg1, 10_000; save_state=true) check_gdemo(chn1) @@ -261,7 +255,7 @@ using Turing smc = SMC() pg = PG(10) - gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10)) chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) @@ -284,7 +278,7 @@ using Turing return s, m end - gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8)) chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2) end @@ -378,9 +372,7 @@ using Turing end end - chain = sample( - StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10; adtype=adbackend), 4000 - ) + chain = sample(StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10), 4000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]) end @@ -411,7 +403,7 @@ using Turing end @testset "sample" begin - alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10)) + alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10)) chn = sample(StableRNG(seed), gdemo_default, alg, 10) end @@ -423,7 +415,7 @@ using Turing return s, m end - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(100) res = sample(StableRNG(seed), vdemo1(x), alg, 10) @@ -438,7 +430,7 @@ using Turing # Vector assumptions N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + alg = HMC(0.2, 4) @model function vdemo3() x = Vector{Real}(undef, N) @@ -493,7 +485,7 @@ using Turing return s, m end - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(100) res = sample(StableRNG(seed), vdemo1(x), alg, 10) @@ -503,12 +495,12 @@ using Turing end D = 2 - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10) # Vector assumptions N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + alg = HMC(0.2, 4) @model function vdemo3() x = Vector{Real}(undef, N) @@ -555,7 +547,7 @@ using Turing @testset "Type parameters" begin N = 10 - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(1000) @model function vdemo1(::Type{T}=Float64) where {T} x = Vector{T}(undef, N) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 5e4ef2149b..2070c697c4 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -373,22 +373,21 @@ end @testset verbose = true "Testing gibbs.jl" begin @info "Starting Gibbs tests" - adbackend = Turing.DEFAULT_ADTYPE @testset "Gibbs constructors" begin # Create Gibbs samplers with various configurations and ways of passing the # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. N = 10 # Two variables being sampled by one sampler. - s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5)) s2 = Gibbs((@varname(s), :m) => PG(10)) # As above but different samplers and using kwargs. - s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15; adtype=adbackend)) - s4 = Gibbs(@varname(s) => HMC(0.1, 5; adtype=adbackend), @varname(m) => ESS()) + s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15)) + s4 = Gibbs(@varname(s) => HMC(0.1, 5), @varname(m) => ESS()) # Multiple instnaces of the same sampler. This implements running, in this case, # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. s5 = begin - hmc = HMC(0.1, 5; adtype=adbackend) + hmc = HMC(0.1, 5) pg = PG(10) vns = @varname(s) vnm = @varname(m) @@ -396,7 +395,7 @@ end end # Same thing but using RepeatSampler. s6 = Gibbs( - @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), + @varname(s) => RepeatSampler(HMC(0.1, 5), 3), @varname(m) => RepeatSampler(PG(10), 2), ) for s in (s1, s2, s3, s4, s5, s6) @@ -418,7 +417,7 @@ end # posterior mean. @testset "Gibbs inference" begin @testset "CSMC and HMC on gdemo" begin - alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -426,7 +425,7 @@ end end @testset "MH and HMCDA on gdemo" begin - alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @@ -447,7 +446,7 @@ end @testset "PG and HMC on MoGtest_default" begin gibbs = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3), ) chain = sample(MoGtest_default, gibbs, 2_000) check_MoGtest_default(chain; atol=0.15) @@ -460,8 +459,8 @@ end (@varname(s), @varname(m)) => MH(), @varname(m) => ESS(), @varname(s) => RepeatSampler(MH(), 3), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), + @varname(m) => HMC(0.2, 4), + (@varname(m), @varname(s)) => HMC(0.2, 4), ) chain = sample(gdemo(1.5, 2.0), alg, 500) check_gdemo(chain; atol=0.15) @@ -471,7 +470,7 @@ end gibbs = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), (@varname(z1), @varname(z2)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3), (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), (@varname(mu1)) => ESS(), (@varname(mu2)) => ESS(), @@ -509,7 +508,7 @@ end return nothing end - alg = Gibbs(:s => MH(), :m => HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMC(0.2, 4)) sample(model, alg, 100; callback=callback) end @@ -539,10 +538,7 @@ end # https://github.com/TuringLang/Turing.jl/issues/1725 # sample(model, Gibbs(:z => MH(), :m => HMC(0.01, 4)), 100); chn = sample( - StableRNG(23), - model, - Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), - num_samples, + StableRNG(23), model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), num_samples ) # The number of m variables that have a non-zero value in a sample. num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2) @@ -592,7 +588,7 @@ end return m .~ Normal(1.0, 1.0) end model = dynamic_model_with_dot_tilde() - sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100) + sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100) end @testset "Demo model" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index fd588b8f0a..4a8604ef69 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -19,7 +19,6 @@ using Turing @testset verbose = true "Testing hmc.jl" begin @info "Starting HMC tests" seed = 123 - adbackend = Turing.DEFAULT_ADTYPE @testset "constrained bounded" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @@ -35,7 +34,7 @@ using Turing chain = sample( StableRNG(seed), constrained_test(obs), - HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) + HMC(1.5, 3),# using a large step size (1.5) 1_000, ) @@ -54,12 +53,7 @@ using Turing return ps end - chain = sample( - StableRNG(seed), - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000, - ) + chain = sample(StableRNG(seed), constrained_simplex_test(obs12), HMC(0.75, 2), 1000) check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end @@ -71,7 +65,7 @@ using Turing model_f = hmcmatrixsup() n_samples = 1_000 - chain = sample(StableRNG(24), model_f, HMC(0.15, 7; adtype=adbackend), n_samples) + chain = sample(StableRNG(24), model_f, HMC(0.15, 7), n_samples) # Reshape the chain into an array of 2x2 matrices, one per sample. Then compute # the average of the samples, as a matrix r = reshape(Array(chain), n_samples, 2, 2) @@ -125,11 +119,11 @@ using Turing end # Sampling - chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5), 10) end @testset "hmcda inference" begin - alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend) + alg1 = HMCDA(500, 0.8, 0.015) res1 = sample(StableRNG(seed), gdemo_default, alg1, 3_000) check_gdemo(res1) end @@ -147,11 +141,11 @@ using Turing end @testset "hmcda constructor" begin - alg = HMCDA(0.8, 0.75; adtype=adbackend) + alg = HMCDA(0.8, 0.75) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "HMCDA" - alg = HMCDA(200, 0.8, 0.75; adtype=adbackend) + alg = HMCDA(200, 0.8, 0.75) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "HMCDA" @@ -160,23 +154,23 @@ using Turing end @testset "nuts inference" begin - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) res = sample(StableRNG(seed), gdemo_default, alg, 5_000) check_gdemo(res) end @testset "nuts constructor" begin - alg = NUTS(200, 0.65; adtype=adbackend) + alg = NUTS(200, 0.65) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "NUTS" - alg = NUTS(0.65; adtype=adbackend) + alg = NUTS(0.65) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "NUTS" end @testset "check discard" begin - alg = NUTS(100, 0.8; adtype=adbackend) + alg = NUTS(100, 0.8) c1 = sample(StableRNG(seed), gdemo_default, alg, 500; discard_adapt=true) c2 = sample(StableRNG(seed), gdemo_default, alg, 500; discard_adapt=false) @@ -186,9 +180,9 @@ using Turing end @testset "AHMC resize" begin - alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65; adtype=adbackend)) - alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3; adtype=adbackend)) - alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3; adtype=adbackend)) + alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65)) + alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3)) + alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3)) @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains @@ -196,7 +190,7 @@ using Turing # issue #1923 @testset "reproducibility" begin - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) res1 = sample(StableRNG(seed), gdemo_default, alg, 10) res2 = sample(StableRNG(seed), gdemo_default, alg, 10) res3 = sample(StableRNG(seed), gdemo_default, alg, 10) @@ -212,7 +206,7 @@ using Turing s ~ prior_dist return m ~ Normal(0, sqrt(s)) end - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) gdemo_default_prior = DynamicPPL.contextualize( demo_hmc_prior(), DynamicPPL.PriorContext() ) @@ -233,7 +227,7 @@ using Turing :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", ) (:info,) match_mode = :any begin - sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) + sample(demo_warn_initial_params(), NUTS(), 5) end end @@ -243,7 +237,7 @@ using Turing Turing.@addlogprob! -Inf end - @test_throws ErrorException sample(demo_impossible(), NUTS(; adtype=adbackend), 5) + @test_throws ErrorException sample(demo_impossible(), NUTS(), 5) end @testset "(partially) issue: #2095" begin diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index cea361c432..6d628ce3cd 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -16,20 +16,20 @@ using Turing @testset verbose = true "Testing sghmc.jl" begin @testset "sghmc constructor" begin - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} end + @testset "sghmc inference" begin rng = StableRNG(123) - - alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=Turing.DEFAULT_ADTYPE) + alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5) chain = sample(rng, gdemo_default, alg, 10_000) check_gdemo(chain; atol=0.1) end @@ -37,12 +37,12 @@ end @testset "Testing sgld.jl" begin @testset "sgld constructor" begin - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE) + alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE) + alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} From 3a43ef710558f3d26255da157c826a4fe197a8e9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 29 May 2025 14:54:18 +0100 Subject: [PATCH 18/18] One more --- test/mcmc/hmc.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 4a8604ef69..0649ca8d39 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -133,9 +133,7 @@ using Turing # explicitly specifying the seeds here. @testset "hmcda+gibbs inference" begin Random.seed!(12345) - alg = Gibbs( - :s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend) - ) + alg = Gibbs(:s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05)) res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000) check_gdemo(res) end