diff --git a/test/Project.toml b/test/Project.toml index a15680418d..d6ed2c79fd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" @@ -52,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36" +DynamicPPL = "0.36.6" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" diff --git a/test/test_utils/ad_utils.jl b/test/ad.jl similarity index 52% rename from test/test_utils/ad_utils.jl rename to test/ad.jl index 309276407a..30cfc13d7a 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/ad.jl @@ -1,18 +1,14 @@ -module ADUtils +module TuringADTests -using ForwardDiff: ForwardDiff -using Pkg: Pkg +using Turing +using DynamicPPL +using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL.TestUtils.AD: run_ad using Random: Random -using ReverseDiff: ReverseDiff -using Mooncake: Mooncake -using Test: Test -using Turing: Turing -using Turing: DynamicPPL - -export ADTypeCheckContext, adbackends - -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# Stuff for checking that the right AD backend is being used. +using StableRNGs: StableRNG +using Test +using ..Models: gdemo_default +import ForwardDiff, ReverseDiff, Mooncake """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) @@ -178,16 +174,122 @@ function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, l return logp, vi end -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# List of AD backends to test. - """ All the ADTypes on which we want to run the tests. """ -adbackends = [ +ADTYPES = [ Turing.AutoForwardDiff(), Turing.AutoReverseDiff(; compile=false), Turing.AutoMooncake(; config=nothing), ] +# Check that ADTypeCheckContext itself works as expected. +@testset "ADTypeCheckContext" begin + @model test_model() = x ~ Normal(0, 1) + tm = test_model() + adtypes = ( + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(), + # Don't need to test Mooncake as it doesn't use tracer types + ) + for actual_adtype in adtypes + sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) + for expected_adtype in adtypes + contextualised_tm = DynamicPPL.contextualize( + tm, ADTypeCheckContext(expected_adtype, tm.context) + ) + @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin + if actual_adtype == expected_adtype + # Check that this does not throw an error. + Turing.sample(contextualised_tm, sampler, 2) + else + @test_throws AbstractWrongADBackendError Turing.sample( + contextualised_tm, sampler, 2 + ) + end + end + end + end +end + +@testset verbose = true "AD / ADTypeCheckContext" begin + # This testset ensures that samplers or optimisers don't accidentally + # override the AD backend set in it. + @testset "adtype=$adtype" for adtype in ADTYPES + seed = 123 + alg = HMC(0.1, 10; adtype=adtype) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + sample(StableRNG(seed), m, alg, 10) + maximum_likelihood(m; adtype=adtype) + maximum_a_posteriori(m; adtype=adtype) + end +end + +@testset verbose = true "AD / SamplingContext" begin + # AD tests for gradient-based samplers need to be run with SamplingContext + # because samplers can potentially use this to define custom behaviour in + # the tilde-pipeline and thus change the code executed during model + # evaluation. + @testset "adtype=$adtype" for adtype in ADTYPES + @testset "alg=$alg" for alg in [ + HMC(0.1, 10; adtype=adtype), + HMCDA(0.8, 0.75; adtype=adtype), + NUTS(1000, 0.8; adtype=adtype), + SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), + SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), + ] + @info "Testing AD for $alg" + + @testset "model=$(model.f)" for model in DEMO_MODELS + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + end +end + +@testset verbose = true "AD / GibbsContext" begin + # Gibbs sampling also needs extra AD testing because the models are + # executed with GibbsContext and a subsetted varinfo. (see e.g. + # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in + # src/mcmc/gibbs.jl -- the code here mimics what happens in those + # functions) + @testset "adtype=$adtype" for adtype in ADTYPES + @testset "model=$(model.f)" for model in DEMO_MODELS + # All the demo models have variables `s` and `m`, so we'll pretend + # that we're using a Gibbs sampler where both of them are sampled + # with a gradient-based sampler (say HMC(0.1, 10)). + # This means we need to construct one with only `s`, and one model with + # only `m`. + global_vi = DynamicPPL.VarInfo(model) + @testset for varnames in ([@varname(s)], [@varname(m)]) + @info "Testing Gibbs AD with model=$(model.f), varnames=$varnames" + conditioned_model = Turing.Inference.make_conditional( + model, varnames, deepcopy(global_vi) + ) + rng = StableRNG(123) + ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) + @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + end + end + end end + +@testset verbose = true "AD / Gibbs sampling" begin + # Make sure that Gibbs sampling doesn't fall over when using AD. + @testset "adtype=$adtype" for adtype in ADTYPES + spl = Gibbs( + @varname(s) => HMC(0.1, 10; adtype=adtype), + @varname(m) => HMC(0.1, 10; adtype=adtype), + ) + @testset "model=$(model.f)" for model in DEMO_MODELS + @test sample(model, spl, 2) isa Any + end + end +end + +end # module diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 55d989b154..78d79a5480 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -2,7 +2,6 @@ module InferenceTests using ..Models: gdemo_d, gdemo_default using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL @@ -17,8 +16,9 @@ import Mooncake using Test: @test, @test_throws, @testset using Turing -@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting Inference.jl tests with $adbackend" +@testset verbose = true "Testing Inference.jl" begin + @info "Starting Inference.jl tests" + seed = 23 @testset "threaded sampling" begin @@ -27,12 +27,12 @@ using Turing model = gdemo_default samplers = ( - HMC(0.1, 7; adtype=adbackend), + HMC(0.1, 7), PG(10), IS(), MH(), - Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)), - Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()), + Gibbs(:s => PG(3), :m => HMC(0.4, 8)), + Gibbs(:s => HMC(0.1, 5), :m => ESS()), ) for sampler in samplers Random.seed!(5) @@ -44,7 +44,7 @@ using Turing @test chain1.value == chain2.value end - # Should also be stable with am explicit RNG + # Should also be stable with an explicit RNG seed = 5 rng = Random.MersenneTwister(seed) for sampler in samplers @@ -61,27 +61,22 @@ using Turing # Smoke test for default sample call. @testset "gdemo_default" begin chain = sample( - StableRNG(seed), - gdemo_default, - HMC(0.1, 7; adtype=adbackend), - MCMCThreads(), - 1_000, - 4, + StableRNG(seed), gdemo_default, HMC(0.1, 7), MCMCThreads(), 1_000, 4 ) check_gdemo(chain) # run sampler: progress logging should be disabled and # it should return a Chains object - sampler = Sampler(HMC(0.1, 7; adtype=adbackend)) + sampler = Sampler(HMC(0.1, 7)) chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4) @test chains isa MCMCChains.Chains end end @testset "chain save/resume" begin - alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) + alg1 = HMCDA(1000, 0.65, 0.15) alg2 = PG(20) - alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend)) + alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4)) chn1 = sample(StableRNG(seed), gdemo_default, alg1, 10_000; save_state=true) check_gdemo(chn1) @@ -260,7 +255,7 @@ using Turing smc = SMC() pg = PG(10) - gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10)) chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) @@ -273,22 +268,17 @@ using Turing @testset "forbid global" begin xs = [1.5 2.0] - # xx = 1 @model function fggibbstest(xs) s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) - # xx ~ Normal(m, sqrt(s)) # this is illegal - for i in 1:length(xs) xs[i] ~ Normal(m, sqrt(s)) - # for xx in xs - # xx ~ Normal(m, sqrt(s)) end return s, m end - gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8)) chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2) end @@ -353,7 +343,7 @@ using Turing ) end - # TODO(mhauru) What is this testing? Why does it not use the looped-over adbackend? + # TODO(mhauru) What is this testing? Why does it use a different adbackend? @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @@ -382,9 +372,7 @@ using Turing end end - chain = sample( - StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10; adtype=adbackend), 4000 - ) + chain = sample(StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10), 4000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]) end @@ -415,7 +403,7 @@ using Turing end @testset "sample" begin - alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10)) + alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10)) chn = sample(StableRNG(seed), gdemo_default, alg, 10) end @@ -427,7 +415,7 @@ using Turing return s, m end - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(100) res = sample(StableRNG(seed), vdemo1(x), alg, 10) @@ -442,7 +430,7 @@ using Turing # Vector assumptions N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + alg = HMC(0.2, 4) @model function vdemo3() x = Vector{Real}(undef, N) @@ -497,7 +485,7 @@ using Turing return s, m end - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(100) res = sample(StableRNG(seed), vdemo1(x), alg, 10) @@ -507,12 +495,12 @@ using Turing end D = 2 - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10) # Vector assumptions N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + alg = HMC(0.2, 4) @model function vdemo3() x = Vector{Real}(undef, N) @@ -559,7 +547,7 @@ using Turing @testset "Type parameters" begin N = 10 - alg = HMC(0.01, 5; adtype=adbackend) + alg = HMC(0.01, 5) x = randn(1000) @model function vdemo1(::Type{T}=Float64) where {T} x = Vector{T}(undef, N) diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index f909df8ef1..7651752ba6 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -1,6 +1,5 @@ module AbstractMCMCTests -import ..ADUtils using AbstractMCMC: AbstractMCMC using AdvancedMH: AdvancedMH using Distributions: sample @@ -110,38 +109,38 @@ function test_initial_params( end end -@testset "External samplers" begin +@testset verbose = true "External samplers" begin @testset "AdvancedHMC.jl" begin - @testset "adtype=$adtype" for adtype in ADUtils.adbackends - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - # Need some functionality to initialize the sampler. - # TODO: Remove this once the constructors in the respective packages become "lazy". - sampler = initialize_nuts(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; adtype, unconstrained=true) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + adtype = Turing.DEFAULT_ADTYPE + + # Need some functionality to initialize the sampler. + # TODO: Remove this once the constructors in the respective packages become "lazy". + sampler = initialize_nuts(model) + sampler_ext = DynamicPPL.Sampler( + externalsampler(sampler; adtype, unconstrained=true) + ) + # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. + # @testset "initial_params" begin + # test_initial_params(model, sampler_ext; n_adapts=0) + # end + + sample_kwargs = ( + n_adapts=1_000, + discard_initial=1_000, + # FIXME: Remove this once we can run `test_initial_params` above. + initial_params=DynamicPPL.VarInfo(model)[:], + ) + + @testset "inference" begin + DynamicPPL.TestUtils.test_sampler( + [model], + sampler_ext, + 2_000; + rtol=0.2, + sampler_name="AdvancedHMC", + sample_kwargs..., ) - # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. - # @testset "initial_params" begin - # test_initial_params(model, sampler_ext; n_adapts=0) - # end - - sample_kwargs = ( - n_adapts=1_000, - discard_initial=1_000, - # FIXME: Remove this once we can run `test_initial_params` above. - initial_params=DynamicPPL.VarInfo(model)[:], - ) - - @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( - [model], - sampler_ext, - 2_000; - rtol=0.2, - sampler_name="AdvancedHMC", - sample_kwargs..., - ) - end end end end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index ec9d42048a..2070c697c4 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -7,7 +7,6 @@ using ..NumericalTests: check_gdemo, check_numerical, two_sample_test -import ..ADUtils import Combinatorics using AbstractMCMC: AbstractMCMC using Distributions: InverseGamma, Normal @@ -34,19 +33,7 @@ function check_transition_varnames(transition::Turing.Inference.Transition, pare end end -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)}, - DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::DynamicPPL.Model) = true - -@testset "GibbsContext" begin +@testset verbose = true "GibbsContext" begin @testset "type stability" begin struct Wrapper{T<:Real} a::T @@ -384,23 +371,23 @@ end @test wuc.non_warmup_count == (num_samples - 1) * num_reps end -@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting Gibbs tests with $adbackend" +@testset verbose = true "Testing gibbs.jl" begin + @info "Starting Gibbs tests" @testset "Gibbs constructors" begin # Create Gibbs samplers with various configurations and ways of passing the # arguments, and run them all on the `gdemo_default` model, see that nothing breaks. N = 10 # Two variables being sampled by one sampler. - s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5)) s2 = Gibbs((@varname(s), :m) => PG(10)) # As above but different samplers and using kwargs. - s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15; adtype=adbackend)) - s4 = Gibbs(@varname(s) => HMC(0.1, 5; adtype=adbackend), @varname(m) => ESS()) + s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15)) + s4 = Gibbs(@varname(s) => HMC(0.1, 5), @varname(m) => ESS()) # Multiple instnaces of the same sampler. This implements running, in this case, # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. s5 = begin - hmc = HMC(0.1, 5; adtype=adbackend) + hmc = HMC(0.1, 5) pg = PG(10) vns = @varname(s) vnm = @varname(m) @@ -408,7 +395,7 @@ end end # Same thing but using RepeatSampler. s6 = Gibbs( - @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), + @varname(s) => RepeatSampler(HMC(0.1, 5), 3), @varname(m) => RepeatSampler(PG(10), 2), ) for s in (s1, s2, s3, s4, s5, s6) @@ -430,7 +417,7 @@ end # posterior mean. @testset "Gibbs inference" begin @testset "CSMC and HMC on gdemo" begin - alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -438,7 +425,7 @@ end end @testset "MH and HMCDA on gdemo" begin - alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @@ -459,7 +446,7 @@ end @testset "PG and HMC on MoGtest_default" begin gibbs = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3), ) chain = sample(MoGtest_default, gibbs, 2_000) check_MoGtest_default(chain; atol=0.15) @@ -472,8 +459,8 @@ end (@varname(s), @varname(m)) => MH(), @varname(m) => ESS(), @varname(s) => RepeatSampler(MH(), 3), - @varname(m) => HMC(0.2, 4; adtype=adbackend), - (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend), + @varname(m) => HMC(0.2, 4), + (@varname(m), @varname(s)) => HMC(0.2, 4), ) chain = sample(gdemo(1.5, 2.0), alg, 500) check_gdemo(chain; atol=0.15) @@ -483,7 +470,7 @@ end gibbs = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), (@varname(z1), @varname(z2)) => PG(15), - (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3), (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), (@varname(mu1)) => ESS(), (@varname(mu2)) => ESS(), @@ -521,7 +508,7 @@ end return nothing end - alg = Gibbs(:s => MH(), :m => HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMC(0.2, 4)) sample(model, alg, 100; callback=callback) end @@ -551,10 +538,7 @@ end # https://github.com/TuringLang/Turing.jl/issues/1725 # sample(model, Gibbs(:z => MH(), :m => HMC(0.01, 4)), 100); chn = sample( - StableRNG(23), - model, - Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), - num_samples, + StableRNG(23), model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), num_samples ) # The number of m variables that have a non-zero value in a sample. num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2) @@ -597,41 +581,27 @@ end @model function dynamic_model_with_dot_tilde( num_zs=10, ::Type{M}=Vector{Float64} ) where {M} - z = M(undef, num_zs) + z = Vector{Int}(undef, num_zs) z .~ Poisson(1.0) num_ms = sum(z) m = M(undef, num_ms) return m .~ Normal(1.0, 1.0) end model = dynamic_model_with_dot_tilde() - # TODO(mhauru) This is broken because of - # https://github.com/TuringLang/DynamicPPL.jl/issues/700. - @test_broken ( - sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100); - true - ) + sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100) end - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "Demo model" begin + @testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) samplers = [ Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => NUTS()), Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => HMC(0.01, 4)), Turing.Gibbs(@varname(s) => NUTS(), @varname(m) => ESS()), + Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), + Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), ] - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Gibbs(@varname(s) => HMC(0.01, 4), @varname(m) => MH()), - Turing.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.01, 4)), - ], - ) - end - @testset "$sampler" for sampler in samplers # Check that taking steps performs as expected. rng = Random.default_rng() @@ -846,7 +816,7 @@ end check_MoGtest_default_z_vector(chain; atol=0.2) end - @testset "externsalsampler" begin + @testset "externalsampler" begin @model function demo_gibbs_external() m1 ~ Normal() m2 ~ Normal() diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 4d884b1637..0649ca8d39 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -1,9 +1,7 @@ module HMCTests using ..Models: gdemo_default -using ..ADUtils: ADTypeCheckContext using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample using DynamicPPL: DynamicPPL, Sampler @@ -18,8 +16,8 @@ import Mooncake using Test: @test, @test_logs, @testset, @test_throws using Turing -@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends - @info "Starting HMC tests with $adbackend" +@testset verbose = true "Testing hmc.jl" begin + @info "Starting HMC tests" seed = 123 @testset "constrained bounded" begin @@ -36,7 +34,7 @@ using Turing chain = sample( StableRNG(seed), constrained_test(obs), - HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) + HMC(1.5, 3),# using a large step size (1.5) 1_000, ) @@ -55,22 +53,11 @@ using Turing return ps end - chain = sample( - StableRNG(seed), - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000, - ) + chain = sample(StableRNG(seed), constrained_simplex_test(obs12), HMC(0.75, 2), 1000) check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end - @testset "hmc reverse diff" begin - alg = HMC(0.1, 10; adtype=adbackend) - res = sample(StableRNG(seed), gdemo_default, alg, 4_000) - check_gdemo(res; rtol=0.1) - end - # Test the sampling of a matrix-value distribution. @testset "matrix support" begin dist = Wishart(7, [1 0.5; 0.5 1]) @@ -78,7 +65,7 @@ using Turing model_f = hmcmatrixsup() n_samples = 1_000 - chain = sample(StableRNG(24), model_f, HMC(0.15, 7; adtype=adbackend), n_samples) + chain = sample(StableRNG(24), model_f, HMC(0.15, 7), n_samples) # Reshape the chain into an array of 2x2 matrices, one per sample. Then compute # the average of the samples, as a matrix r = reshape(Array(chain), n_samples, 2, 2) @@ -132,11 +119,11 @@ using Turing end # Sampling - chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5), 10) end @testset "hmcda inference" begin - alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend) + alg1 = HMCDA(500, 0.8, 0.015) res1 = sample(StableRNG(seed), gdemo_default, alg1, 3_000) check_gdemo(res1) end @@ -146,19 +133,17 @@ using Turing # explicitly specifying the seeds here. @testset "hmcda+gibbs inference" begin Random.seed!(12345) - alg = Gibbs( - :s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend) - ) + alg = Gibbs(:s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05)) res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000) check_gdemo(res) end @testset "hmcda constructor" begin - alg = HMCDA(0.8, 0.75; adtype=adbackend) + alg = HMCDA(0.8, 0.75) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "HMCDA" - alg = HMCDA(200, 0.8, 0.75; adtype=adbackend) + alg = HMCDA(200, 0.8, 0.75) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "HMCDA" @@ -167,23 +152,23 @@ using Turing end @testset "nuts inference" begin - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) res = sample(StableRNG(seed), gdemo_default, alg, 5_000) check_gdemo(res) end @testset "nuts constructor" begin - alg = NUTS(200, 0.65; adtype=adbackend) + alg = NUTS(200, 0.65) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "NUTS" - alg = NUTS(0.65; adtype=adbackend) + alg = NUTS(0.65) sampler = Sampler(alg) @test DynamicPPL.alg_str(sampler) == "NUTS" end @testset "check discard" begin - alg = NUTS(100, 0.8; adtype=adbackend) + alg = NUTS(100, 0.8) c1 = sample(StableRNG(seed), gdemo_default, alg, 500; discard_adapt=true) c2 = sample(StableRNG(seed), gdemo_default, alg, 500; discard_adapt=false) @@ -193,9 +178,9 @@ using Turing end @testset "AHMC resize" begin - alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65; adtype=adbackend)) - alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3; adtype=adbackend)) - alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3; adtype=adbackend)) + alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65)) + alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3)) + alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3)) @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains @@ -203,7 +188,7 @@ using Turing # issue #1923 @testset "reproducibility" begin - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) res1 = sample(StableRNG(seed), gdemo_default, alg, 10) res2 = sample(StableRNG(seed), gdemo_default, alg, 10) res3 = sample(StableRNG(seed), gdemo_default, alg, 10) @@ -211,20 +196,20 @@ using Turing end @testset "prior" begin + # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance + # which means that it's _very_ difficult to find a good tolerance in the test below:) + prior_dist = truncated(Normal(3, 1); lower=0) + @model function demo_hmc_prior() - # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance - # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1); lower=0) + s ~ prior_dist return m ~ Normal(0, sqrt(s)) end - alg = NUTS(1000, 0.8; adtype=adbackend) + alg = NUTS(1000, 0.8) gdemo_default_prior = DynamicPPL.contextualize( demo_hmc_prior(), DynamicPPL.PriorContext() ) chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0]) - check_numerical( - chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 - ) + check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2) end @testset "warning for difficult init params" begin @@ -240,7 +225,7 @@ using Turing :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", ) (:info,) match_mode = :any begin - sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) + sample(demo_warn_initial_params(), NUTS(), 5) end end @@ -250,7 +235,7 @@ using Turing Turing.@addlogprob! -Inf end - @test_throws ErrorException sample(demo_impossible(), NUTS(; adtype=adbackend), 5) + @test_throws ErrorException sample(demo_impossible(), NUTS(), 5) end @testset "(partially) issue: #2095" begin @@ -292,8 +277,8 @@ using Turing # Extract the `x` like this because running `generated_quantities` was how # the issue was discovered, hence we also want to make sure that it works. - results = generated_quantities(model, chain) - results_prior = generated_quantities(model, chain_prior) + results = returned(model, chain) + results_prior = returned(model, chain_prior) # Make sure none of the samples in the chains resulted in errors. @test all(!isnothing, results) @@ -315,15 +300,6 @@ using Turing @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 end end - - @testset "Check ADType" begin - alg = HMC(0.1, 10; adtype=adbackend) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(StableRNG(seed), m, alg, 10) - end end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index c878f755de..6d628ce3cd 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -2,7 +2,9 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo -import ..ADUtils +using DynamicPPL.TestUtils.AD: run_ad +using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL: DynamicPPL using Distributions: sample import ForwardDiff using LinearAlgebra: dot @@ -12,35 +14,35 @@ import Mooncake using Test: @test, @testset using Turing -@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends +@testset verbose = true "Testing sghmc.jl" begin @testset "sghmc constructor" begin - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} end + @testset "sghmc inference" begin rng = StableRNG(123) - - alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend) + alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5) chain = sample(rng, gdemo_default, alg, 10_000) check_gdemo(chain; atol=0.1) end end -@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends +@testset "Testing sgld.jl" begin @testset "sgld constructor" begin - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) + alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} - alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) + alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 9894d621ce..cf692cba25 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -1,7 +1,6 @@ module OptimisationTests using ..Models: gdemo, gdemo_default -using ..ADUtils: ADUtils using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL @@ -624,16 +623,6 @@ using Turing @assert get(result, :c) == (; :c => Array{Float64}[]) end - @testset "ADType test with $adbackend" for adbackend in ADUtils.adbackends - Random.seed!(222) - m = DynamicPPL.contextualize( - gdemo_default, ADUtils.ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - maximum_likelihood(m; adtype=adbackend) - maximum_a_posteriori(m; adtype=adbackend) - end - @testset "Collinear coeftable" begin xs = [-1.0, 0.0, 1.0] ys = [0.0, 0.0, 0.0] diff --git a/test/runtests.jl b/test/runtests.jl index 47b714188e..69f8045dd8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,9 +9,8 @@ import Turing # Fix the global Random.seed for reproducibility. seed!(23) -include(pkgdir(Turing) * "/test/test_utils/models.jl") -include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl") -include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl") +include("test_utils/models.jl") +include("test_utils/numerical_tests.jl") Turing.setprogress!(false) included_paths, excluded_paths = parse_args(ARGS) @@ -30,14 +29,14 @@ macro timeit_include(path::AbstractString) end @testset "Turing" verbose = true begin - @testset "Test utils" begin - @timeit_include("test_utils/test_utils.jl") - end - @testset "Aqua" begin @timeit_include("Aqua.jl") end + @testset "AD" verbose = true begin + @timeit_include("ad.jl") + end + @testset "essential" verbose = true begin @timeit_include("essential/container.jl") end diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl deleted file mode 100644 index 243a80b881..0000000000 --- a/test/test_utils/test_utils.jl +++ /dev/null @@ -1,41 +0,0 @@ -"""Module for testing the test utils themselves.""" -module TestUtilsTests - -using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -using Test: @test, @testset, @test_throws -using Turing: Turing -using Turing: DynamicPPL - -# Check that the ADTypeCheckContext works as expected. -@testset "ADTypeCheckContext" begin - Turing.@model test_model() = x ~ Turing.Normal(0, 1) - tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), - ) - for actual_adtype in adtypes - sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) - for expected_adtype in adtypes - contextualised_tm = DynamicPPL.contextualize( - tm, ADTypeCheckContext(expected_adtype, tm.context) - ) - @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin - if actual_adtype == expected_adtype - # Check that this does not throw an error. - Turing.sample(contextualised_tm, sampler, 2) - else - @test_throws AbstractWrongADBackendError Turing.sample( - contextualised_tm, sampler, 2 - ) - end - end - end - end -end - -end