Skip to content

Commit be08026

Browse files
committed
Separate AD tests into their own file
1 parent 0975a91 commit be08026

File tree

4 files changed

+77
-53
lines changed

4 files changed

+77
-53
lines changed

test/ad.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module TuringADTests
2+
3+
using Turing
4+
using DynamicPPL
5+
using DynamicPPL.TestUtils: DEMO_MODELS
6+
using DynamicPPL.TestUtils.AD: run_ad
7+
using StableRNGs: StableRNG
8+
using Test
9+
using ..Models: gdemo_default
10+
using ..ADUtils: ADTypeCheckContext, adbackends
11+
12+
@testset verbose = true "AD / SamplingContext" begin
13+
# AD tests for gradient-based samplers need to be run with SamplingContext
14+
# because samplers can potentially use this to define custom behaviour in
15+
# the tilde-pipeline and thus change the code executed during model
16+
# evaluation.
17+
@testset "adtype=$adtype" for adtype in adbackends
18+
@testset "alg=$alg" for alg in [
19+
HMC(0.1, 10; adtype=adtype),
20+
HMCDA(0.8, 0.75; adtype=adtype),
21+
NUTS(1000, 0.8; adtype=adtype),
22+
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
23+
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
24+
]
25+
@info "Testing AD for $alg"
26+
27+
@testset "model=$(model.f)" for model in DEMO_MODELS
28+
rng = StableRNG(123)
29+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
30+
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
31+
end
32+
end
33+
34+
@testset "Check ADType" begin
35+
seed = 123
36+
alg = HMC(0.1, 10; adtype=adtype)
37+
m = DynamicPPL.contextualize(
38+
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
39+
)
40+
# These will error if the adbackend being used is not the one set.
41+
sample(StableRNG(seed), m, alg, 10)
42+
end
43+
end
44+
end
45+
46+
@testset verbose = true "AD / GibbsContext" begin
47+
# Gibbs sampling also needs extra AD testing because the models are
48+
# executed with GibbsContext and a subsetted varinfo. (see e.g.
49+
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
50+
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
51+
# functions)
52+
@testset "adtype=$adtype" for adtype in adbackends
53+
@testset "model=$(model.f)" for model in DEMO_MODELS
54+
# All the demo models have variables `s` and `m`, so we'll pretend
55+
# that we're using a Gibbs sampler where both of them are sampled
56+
# with a gradient-based sampler (say HMC(0.1, 10)).
57+
# This means we need to construct one with only `s`, and one model with
58+
# only `m`.
59+
global_vi = DynamicPPL.VarInfo(model)
60+
@testset for varnames in ([@varname(s)], [@varname(m)])
61+
@info "Testing Gibbs AD with model=$(model.f), varnames=$varnames"
62+
conditioned_model = Turing.Inference.make_conditional(
63+
model, varnames, deepcopy(global_vi)
64+
)
65+
rng = StableRNG(123)
66+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
67+
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
68+
end
69+
end
70+
end
71+
end
72+
73+
end # module

test/mcmc/hmc.jl

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import ..ADUtils
77
using Bijectors: Bijectors
88
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
99
using DynamicPPL: DynamicPPL, Sampler
10-
using DynamicPPL.TestUtils.AD: run_ad
11-
using DynamicPPL.TestUtils: DEMO_MODELS
1210
import ForwardDiff
1311
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
1412
import ReverseDiff
@@ -20,37 +18,6 @@ import Mooncake
2018
using Test: @test, @test_logs, @testset, @test_throws
2119
using Turing
2220

23-
@testset "AD / hmc.jl" begin
24-
# AD tests need to be run with SamplingContext because samplers can potentially
25-
# use this to define custom behaviour in the tilde-pipeline and thus change the
26-
# code executed during model evaluation.
27-
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
28-
@testset "alg=$alg" for alg in [
29-
HMC(0.1, 10; adtype=adtype),
30-
HMCDA(0.8, 0.75; adtype=adtype),
31-
NUTS(1000, 0.8; adtype=adtype),
32-
]
33-
@info "Testing AD for $alg"
34-
35-
@testset "model=$(model.f)" for model in DEMO_MODELS
36-
rng = StableRNG(123)
37-
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
38-
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
39-
end
40-
end
41-
42-
@testset "Check ADType" begin
43-
seed = 123
44-
alg = HMC(0.1, 10; adtype=adtype)
45-
m = DynamicPPL.contextualize(
46-
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
47-
)
48-
# These will error if the adbackend being used is not the one set.
49-
sample(StableRNG(seed), m, alg, 10)
50-
end
51-
end
52-
end
53-
5421
@testset verbose = true "Testing hmc.jl" begin
5522
@info "Starting HMC tests"
5623
seed = 123

test/mcmc/sghmc.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,6 @@ import Mooncake
1515
using Test: @test, @testset
1616
using Turing
1717

18-
@testset "AD / sghmc.jl" begin
19-
# AD tests need to be run with SamplingContext because samplers can potentially
20-
# use this to define custom behaviour in the tilde-pipeline and thus change the
21-
# code executed during model evaluation.
22-
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
23-
@testset "alg=$alg" for alg in [
24-
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
25-
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
26-
]
27-
@info "Testing AD for $alg"
28-
29-
@testset "model=$(model.f)" for model in DEMO_MODELS
30-
rng = StableRNG(123)
31-
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
32-
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
33-
end
34-
end
35-
end
36-
end
37-
3818
@testset verbose = true "Testing sghmc.jl" begin
3919
@testset "sghmc constructor" begin
4020
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ end
3838
@timeit_include("Aqua.jl")
3939
end
4040

41+
@testset "AD" verbose = true begin
42+
@timeit_include("ad.jl")
43+
end
44+
4145
@testset "essential" verbose = true begin
4246
@timeit_include("essential/container.jl")
4347
end

0 commit comments

Comments
 (0)