Skip to content

Commit 1a70627

Browse files
authored
Remove AD backend loops in test suite (#2564)
* Remove AD backend loops
* Add AD tests to sghmc.jl
* Remove AD loops from HMC and SGHMC
* Handle Inference.jl
* Remove AD loop in Gibbs
* Fix broken test
* Mark testsets as verbose=true
* Separate AD tests into their own file
* Finish AD tests
* No I did not mean to comment out all other tests
* Move ADTypeCheckContexts and optimisation tests to ad.jl as well
* Remove stray test_utils file
* Remove has_dot_assume check
* Remove TODO comment about Mooncake
* Test Gibbs sampling
* Fix test
* Remove unnecessary DEFAULT_ADTYPEs
* One more
1 parent dea5d19 commit 1a70627

File tree

10 files changed

+242
-257
lines changed

10 files changed

+242
-257
lines changed

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
34
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
45
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -52,7 +53,7 @@ Combinatorics = "1"
5253
Distributions = "0.25"
5354
DistributionsAD = "0.6.3"
5455
DynamicHMC = "2.1.6, 3.0"
55-
DynamicPPL = "0.36"
56+
DynamicPPL = "0.36.6"
5657
FiniteDifferences = "0.10.8, 0.11, 0.12"
5758
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
5859
HypothesisTests = "0.11"

test/test_utils/ad_utils.jl renamed to test/ad.jl

Lines changed: 119 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
module ADUtils
1+
module TuringADTests
22

3-
using ForwardDiff: ForwardDiff
4-
using Pkg: Pkg
3+
using Turing
4+
using DynamicPPL
5+
using DynamicPPL.TestUtils: DEMO_MODELS
6+
using DynamicPPL.TestUtils.AD: run_ad
57
using Random: Random
6-
using ReverseDiff: ReverseDiff
7-
using Mooncake: Mooncake
8-
using Test: Test
9-
using Turing: Turing
10-
using Turing: DynamicPPL
11-
12-
export ADTypeCheckContext, adbackends
13-
14-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
15-
# Stuff for checking that the right AD backend is being used.
8+
using StableRNGs: StableRNG
9+
using Test
10+
using ..Models: gdemo_default
11+
import ForwardDiff, ReverseDiff, Mooncake
1612

1713
"""Element types that are always valid for a VarInfo regardless of ADType."""
1814
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -178,16 +174,122 @@ function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, l
178174
return logp, vi
179175
end
180176

181-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
182-
# List of AD backends to test.
183-
184177
"""
185178
All the ADTypes on which we want to run the tests.
186179
"""
187-
adbackends = [
180+
ADTYPES = [
188181
Turing.AutoForwardDiff(),
189182
Turing.AutoReverseDiff(; compile=false),
190183
Turing.AutoMooncake(; config=nothing),
191184
]
192185

186+
# Check that ADTypeCheckContext itself works as expected.
187+
@testset "ADTypeCheckContext" begin
188+
@model test_model() = x ~ Normal(0, 1)
189+
tm = test_model()
190+
adtypes = (
191+
Turing.AutoForwardDiff(),
192+
Turing.AutoReverseDiff(),
193+
# Don't need to test Mooncake as it doesn't use tracer types
194+
)
195+
for actual_adtype in adtypes
196+
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
197+
for expected_adtype in adtypes
198+
contextualised_tm = DynamicPPL.contextualize(
199+
tm, ADTypeCheckContext(expected_adtype, tm.context)
200+
)
201+
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
202+
if actual_adtype == expected_adtype
203+
# Check that this does not throw an error.
204+
Turing.sample(contextualised_tm, sampler, 2)
205+
else
206+
@test_throws AbstractWrongADBackendError Turing.sample(
207+
contextualised_tm, sampler, 2
208+
)
209+
end
210+
end
211+
end
212+
end
213+
end
214+
215+
@testset verbose = true "AD / ADTypeCheckContext" begin
216+
# This testset ensures that samplers or optimisers don't accidentally
217+
# override the AD backend set in it.
218+
@testset "adtype=$adtype" for adtype in ADTYPES
219+
seed = 123
220+
alg = HMC(0.1, 10; adtype=adtype)
221+
m = DynamicPPL.contextualize(
222+
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
223+
)
224+
# These will error if the adbackend being used is not the one set.
225+
sample(StableRNG(seed), m, alg, 10)
226+
maximum_likelihood(m; adtype=adtype)
227+
maximum_a_posteriori(m; adtype=adtype)
228+
end
229+
end
230+
231+
@testset verbose = true "AD / SamplingContext" begin
232+
# AD tests for gradient-based samplers need to be run with SamplingContext
233+
# because samplers can potentially use this to define custom behaviour in
234+
# the tilde-pipeline and thus change the code executed during model
235+
# evaluation.
236+
@testset "adtype=$adtype" for adtype in ADTYPES
237+
@testset "alg=$alg" for alg in [
238+
HMC(0.1, 10; adtype=adtype),
239+
HMCDA(0.8, 0.75; adtype=adtype),
240+
NUTS(1000, 0.8; adtype=adtype),
241+
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
242+
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
243+
]
244+
@info "Testing AD for $alg"
245+
246+
@testset "model=$(model.f)" for model in DEMO_MODELS
247+
rng = StableRNG(123)
248+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
249+
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
250+
end
251+
end
252+
end
253+
end
254+
255+
@testset verbose = true "AD / GibbsContext" begin
256+
# Gibbs sampling also needs extra AD testing because the models are
257+
# executed with GibbsContext and a subsetted varinfo. (see e.g.
258+
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
259+
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
260+
# functions)
261+
@testset "adtype=$adtype" for adtype in ADTYPES
262+
@testset "model=$(model.f)" for model in DEMO_MODELS
263+
# All the demo models have variables `s` and `m`, so we'll pretend
264+
# that we're using a Gibbs sampler where both of them are sampled
265+
# with a gradient-based sampler (say HMC(0.1, 10)).
266+
# This means we need to construct one model with only `s`, and one with
267+
# only `m`.
268+
global_vi = DynamicPPL.VarInfo(model)
269+
@testset for varnames in ([@varname(s)], [@varname(m)])
270+
@info "Testing Gibbs AD with model=$(model.f), varnames=$varnames"
271+
conditioned_model = Turing.Inference.make_conditional(
272+
model, varnames, deepcopy(global_vi)
273+
)
274+
rng = StableRNG(123)
275+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
276+
# Differentiate `conditioned_model` (the result of `make_conditional` above),
# not the unconditioned `model`: otherwise `conditioned_model` is unused and
# this check does not exercise the Gibbs-specific model at all.
@test run_ad(conditioned_model, adtype; context=ctx, test=true, benchmark=false) isa Any
277+
end
278+
end
279+
end
193280
end
281+
282+
@testset verbose = true "AD / Gibbs sampling" begin
283+
# Make sure that Gibbs sampling doesn't fall over when using AD.
284+
@testset "adtype=$adtype" for adtype in ADTYPES
285+
spl = Gibbs(
286+
@varname(s) => HMC(0.1, 10; adtype=adtype),
287+
@varname(m) => HMC(0.1, 10; adtype=adtype),
288+
)
289+
@testset "model=$(model.f)" for model in DEMO_MODELS
290+
@test sample(model, spl, 2) isa Any
291+
end
292+
end
293+
end
294+
295+
end # module

test/mcmc/Inference.jl

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module InferenceTests
22

33
using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5-
import ..ADUtils
65
using Distributions: Bernoulli, Beta, InverseGamma, Normal
76
using Distributions: sample
87
import DynamicPPL
@@ -17,8 +16,9 @@ import Mooncake
1716
using Test: @test, @test_throws, @testset
1817
using Turing
1918

20-
@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
21-
@info "Starting Inference.jl tests with $adbackend"
19+
@testset verbose = true "Testing Inference.jl" begin
20+
@info "Starting Inference.jl tests"
21+
2222
seed = 23
2323

2424
@testset "threaded sampling" begin
@@ -27,12 +27,12 @@ using Turing
2727
model = gdemo_default
2828

2929
samplers = (
30-
HMC(0.1, 7; adtype=adbackend),
30+
HMC(0.1, 7),
3131
PG(10),
3232
IS(),
3333
MH(),
34-
Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)),
35-
Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()),
34+
Gibbs(:s => PG(3), :m => HMC(0.4, 8)),
35+
Gibbs(:s => HMC(0.1, 5), :m => ESS()),
3636
)
3737
for sampler in samplers
3838
Random.seed!(5)
@@ -44,7 +44,7 @@ using Turing
4444
@test chain1.value == chain2.value
4545
end
4646

47-
# Should also be stable with am explicit RNG
47+
# Should also be stable with an explicit RNG
4848
seed = 5
4949
rng = Random.MersenneTwister(seed)
5050
for sampler in samplers
@@ -61,27 +61,22 @@ using Turing
6161
# Smoke test for default sample call.
6262
@testset "gdemo_default" begin
6363
chain = sample(
64-
StableRNG(seed),
65-
gdemo_default,
66-
HMC(0.1, 7; adtype=adbackend),
67-
MCMCThreads(),
68-
1_000,
69-
4,
64+
StableRNG(seed), gdemo_default, HMC(0.1, 7), MCMCThreads(), 1_000, 4
7065
)
7166
check_gdemo(chain)
7267

7368
# run sampler: progress logging should be disabled and
7469
# it should return a Chains object
75-
sampler = Sampler(HMC(0.1, 7; adtype=adbackend))
70+
sampler = Sampler(HMC(0.1, 7))
7671
chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4)
7772
@test chains isa MCMCChains.Chains
7873
end
7974
end
8075

8176
@testset "chain save/resume" begin
82-
alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend)
77+
alg1 = HMCDA(1000, 0.65, 0.15)
8378
alg2 = PG(20)
84-
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend))
79+
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4))
8580

8681
chn1 = sample(StableRNG(seed), gdemo_default, alg1, 10_000; save_state=true)
8782
check_gdemo(chn1)
@@ -260,7 +255,7 @@ using Turing
260255

261256
smc = SMC()
262257
pg = PG(10)
263-
gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10))
258+
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
264259

265260
chn_s = sample(StableRNG(seed), testbb(obs), smc, 200)
266261
chn_p = sample(StableRNG(seed), testbb(obs), pg, 200)
@@ -273,22 +268,17 @@ using Turing
273268

274269
@testset "forbid global" begin
275270
xs = [1.5 2.0]
276-
# xx = 1
277271

278272
@model function fggibbstest(xs)
279273
s ~ InverseGamma(2, 3)
280274
m ~ Normal(0, sqrt(s))
281-
# xx ~ Normal(m, sqrt(s)) # this is illegal
282-
283275
for i in 1:length(xs)
284276
xs[i] ~ Normal(m, sqrt(s))
285-
# for xx in xs
286-
# xx ~ Normal(m, sqrt(s))
287277
end
288278
return s, m
289279
end
290280

291-
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend))
281+
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
292282
chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2)
293283
end
294284

@@ -353,7 +343,7 @@ using Turing
353343
)
354344
end
355345

356-
# TODO(mhauru) What is this testing? Why does it not use the looped-over adbackend?
346+
# TODO(mhauru) What is this testing? Why does it use a different adbackend?
357347
@testset "new interface" begin
358348
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
359349

@@ -382,9 +372,7 @@ using Turing
382372
end
383373
end
384374

385-
chain = sample(
386-
StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10; adtype=adbackend), 4000
387-
)
375+
chain = sample(StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10), 4000)
388376
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6])
389377
end
390378

@@ -415,7 +403,7 @@ using Turing
415403
end
416404

417405
@testset "sample" begin
418-
alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10))
406+
alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
419407
chn = sample(StableRNG(seed), gdemo_default, alg, 10)
420408
end
421409

@@ -427,7 +415,7 @@ using Turing
427415
return s, m
428416
end
429417

430-
alg = HMC(0.01, 5; adtype=adbackend)
418+
alg = HMC(0.01, 5)
431419
x = randn(100)
432420
res = sample(StableRNG(seed), vdemo1(x), alg, 10)
433421

@@ -442,7 +430,7 @@ using Turing
442430

443431
# Vector assumptions
444432
N = 10
445-
alg = HMC(0.2, 4; adtype=adbackend)
433+
alg = HMC(0.2, 4)
446434

447435
@model function vdemo3()
448436
x = Vector{Real}(undef, N)
@@ -497,7 +485,7 @@ using Turing
497485
return s, m
498486
end
499487

500-
alg = HMC(0.01, 5; adtype=adbackend)
488+
alg = HMC(0.01, 5)
501489
x = randn(100)
502490
res = sample(StableRNG(seed), vdemo1(x), alg, 10)
503491

@@ -507,12 +495,12 @@ using Turing
507495
end
508496

509497
D = 2
510-
alg = HMC(0.01, 5; adtype=adbackend)
498+
alg = HMC(0.01, 5)
511499
res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10)
512500

513501
# Vector assumptions
514502
N = 10
515-
alg = HMC(0.2, 4; adtype=adbackend)
503+
alg = HMC(0.2, 4)
516504

517505
@model function vdemo3()
518506
x = Vector{Real}(undef, N)
@@ -559,7 +547,7 @@ using Turing
559547

560548
@testset "Type parameters" begin
561549
N = 10
562-
alg = HMC(0.01, 5; adtype=adbackend)
550+
alg = HMC(0.01, 5)
563551
x = randn(1000)
564552
@model function vdemo1(::Type{T}=Float64) where {T}
565553
x = Vector{T}(undef, N)

0 commit comments

Comments
 (0)