|
| 1 | +module TuringADTests |
| 2 | + |
| 3 | +using Turing |
| 4 | +using DynamicPPL |
| 5 | +using DynamicPPL.TestUtils: DEMO_MODELS |
| 6 | +using DynamicPPL.TestUtils.AD: run_ad |
| 7 | +using StableRNGs: StableRNG |
| 8 | +using Test |
| 9 | +using ..Models: gdemo_default |
| 10 | +using ..ADUtils: ADTypeCheckContext, adbackends |
| 11 | + |
@testset verbose = true "AD / SamplingContext" begin
    # AD tests for gradient-based samplers need to be run with SamplingContext
    # because samplers can potentially use this to define custom behaviour in
    # the tilde-pipeline and thus change the code executed during model
    # evaluation.
    @testset "adtype=$adtype" for adtype in adbackends
        # Every gradient-based sampler we support, each configured to use the
        # AD backend under test.
        gradient_samplers = [
            HMC(0.1, 10; adtype=adtype),
            HMCDA(0.8, 0.75; adtype=adtype),
            NUTS(1000, 0.8; adtype=adtype),
            SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
            SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
        ]
        @testset "alg=$alg" for alg in gradient_samplers
            @info "Testing AD for $alg"

            @testset "model=$(model.f)" for model in DEMO_MODELS
                sampling_ctx = DynamicPPL.SamplingContext(
                    StableRNG(123), DynamicPPL.Sampler(alg)
                )
                @test run_ad(
                    model, adtype; context=sampling_ctx, test=true, benchmark=false
                ) isa Any
            end
        end

        @testset "Check ADType" begin
            # Wrap gdemo_default's evaluation context so that the model itself
            # verifies which AD backend is in use during sampling.
            checked_model = DynamicPPL.contextualize(
                gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
            )
            # These will error if the adbackend being used is not the one set.
            sample(StableRNG(123), checked_model, HMC(0.1, 10; adtype=adtype), 10)
        end
    end
end
| 45 | + |
@testset verbose = true "AD / GibbsContext" begin
    # Gibbs sampling also needs extra AD testing because the models are
    # executed with GibbsContext and a subsetted varinfo. (see e.g.
    # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
    # src/mcmc/gibbs.jl -- the code here mimics what happens in those
    # functions)
    @testset "adtype=$adtype" for adtype in adbackends
        @testset "model=$(model.f)" for model in DEMO_MODELS
            # All the demo models have variables `s` and `m`, so we'll pretend
            # that we're using a Gibbs sampler where both of them are sampled
            # with a gradient-based sampler (say HMC(0.1, 10)).
            # This means we need to construct one model with only `s`, and one
            # model with only `m`.
            global_vi = DynamicPPL.VarInfo(model)
            @testset for varnames in ([@varname(s)], [@varname(m)])
                @info "Testing Gibbs AD with model=$(model.f), varnames=$varnames"
                # deepcopy so that conditioning one subset does not mutate the
                # varinfo shared across the two iterations of this loop.
                conditioned_model = Turing.Inference.make_conditional(
                    model, varnames, deepcopy(global_vi)
                )
                rng = StableRNG(123)
                ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
                # Fix: previously this called `run_ad(model, ...)`, leaving
                # `conditioned_model` unused -- so the GibbsContext/subsetted
                # varinfo path this testset exists to cover was never
                # exercised. Run AD on the conditioned model instead.
                @test run_ad(
                    conditioned_model, adtype; context=ctx, test=true, benchmark=false
                ) isa Any
            end
        end
    end
end
| 72 | + |
| 73 | +end # module |
0 commit comments