diff --git a/test/ad.jl b/test/ad.jl index f53dd9835..dcfe4ef46 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -237,34 +237,8 @@ end end end -@testset verbose = true "AD / SamplingContext" begin - # AD tests for gradient-based samplers need to be run with SamplingContext - # because samplers can potentially use this to define custom behaviour in - # the tilde-pipeline and thus change the code executed during model - # evaluation. - @testset "adtype=$adtype" for adtype in ADTYPES - @testset "alg=$alg" for alg in [ - HMC(0.1, 10; adtype=adtype), - HMCDA(0.8, 0.75; adtype=adtype), - NUTS(1000, 0.8; adtype=adtype), - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - ) - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any - end - end - end -end - @testset verbose = true "AD / GibbsContext" begin - # Gibbs sampling also needs extra AD testing because the models are + # Gibbs sampling needs some extra AD testing because the models are # executed with GibbsContext and a subsetted varinfo. (see e.g. # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in # src/mcmc/gibbs.jl -- the code here mimics what happens in those @@ -283,10 +257,7 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - ) - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any + @test run_ad(model, adtype; test=true, benchmark=false) isa Any end end end