Skip to content

Commit 83ed0ad

Browse files
committed
Move ADTypeCheckContexts and optimisation tests to ad.jl as well
1 parent 5d654c7 commit 83ed0ad

File tree

8 files changed

+35
-50
lines changed

8 files changed

+35
-50
lines changed

test/ad.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,39 @@ ADTYPES = [
181181
Turing.AutoMooncake(; config=nothing),
182182
]
183183

184+
# Check that ADTypeCheckContext itself works as expected.
185+
@testset "ADTypeCheckContext" begin
186+
@model test_model() = x ~ Normal(0, 1)
187+
tm = test_model()
188+
adtypes = (
189+
Turing.AutoForwardDiff(),
190+
Turing.AutoReverseDiff(),
191+
# TODO: Mooncake
192+
# Turing.AutoMooncake(config=nothing),
193+
)
194+
for actual_adtype in adtypes
195+
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
196+
for expected_adtype in adtypes
197+
contextualised_tm = DynamicPPL.contextualize(
198+
tm, ADTypeCheckContext(expected_adtype, tm.context)
199+
)
200+
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
201+
if actual_adtype == expected_adtype
202+
# Check that this does not throw an error.
203+
Turing.sample(contextualised_tm, sampler, 2)
204+
else
205+
@test_throws AbstractWrongADBackendError Turing.sample(
206+
contextualised_tm, sampler, 2
207+
)
208+
end
209+
end
210+
end
211+
end
212+
end
213+
184214
@testset verbose = true "AD / ADTypeCheckContext" begin
185-
# This testset ensures that samplers don't accidentally override the AD
186-
# backend set in it.
215+
# This testset ensures that samplers or optimisers don't accidentally
216+
# override the AD backend set in it.
187217
@testset "Check ADType" begin
188218
seed = 123
189219
alg = HMC(0.1, 10; adtype=adtype)
@@ -192,6 +222,8 @@ ADTYPES = [
192222
)
193223
# These will error if the adbackend being used is not the one set.
194224
sample(StableRNG(seed), m, alg, 10)
225+
maximum_likelihood(m; adtype=adbackend)
226+
maximum_a_posteriori(m; adtype=adbackend)
195227
end
196228
end
197229

test/mcmc/Inference.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module InferenceTests
22

33
using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5-
import ..ADUtils
65
using Distributions: Bernoulli, Beta, InverseGamma, Normal
76
using Distributions: sample
87
import DynamicPPL

test/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module AbstractMCMCTests
22

3-
import ..ADUtils
4-
using AbstractMCMC: AbstractMCMC
3+
using AbsttractMCMC: AbstractMCMC
54
using AdvancedMH: AdvancedMH
65
using Distributions: sample
76
using Distributions.FillArrays: Zeros

test/mcmc/gibbs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using ..NumericalTests:
77
check_gdemo,
88
check_numerical,
99
two_sample_test
10-
import ..ADUtils
1110
import Combinatorics
1211
using AbstractMCMC: AbstractMCMC
1312
using Distributions: InverseGamma, Normal

test/mcmc/hmc.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
module HMCTests
22

33
using ..Models: gdemo_default
4-
using ..ADUtils: ADTypeCheckContext
54
using ..NumericalTests: check_gdemo, check_numerical
6-
import ..ADUtils
75
using Bijectors: Bijectors
86
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
97
using DynamicPPL: DynamicPPL, Sampler

test/mcmc/sghmc.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module SGHMCTests
22

33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
5-
import ..ADUtils
65
using DynamicPPL.TestUtils.AD: run_ad
76
using DynamicPPL.TestUtils: DEMO_MODELS
87
using DynamicPPL: DynamicPPL

test/optimisation/Optimisation.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4-
using ..ADUtils: ADUtils
54
using Distributions
65
using Distributions.FillArrays: Zeros
76
using DynamicPPL: DynamicPPL
@@ -624,16 +623,6 @@ using Turing
624623
@assert get(result, :c) == (; :c => Array{Float64}[])
625624
end
626625

627-
@testset "ADType test with $adbackend" for adbackend in ADUtils.adbackends
628-
Random.seed!(222)
629-
m = DynamicPPL.contextualize(
630-
gdemo_default, ADUtils.ADTypeCheckContext(adbackend, gdemo_default.context)
631-
)
632-
# These will error if the adbackend being used is not the one set.
633-
maximum_likelihood(m; adtype=adbackend)
634-
maximum_a_posteriori(m; adtype=adbackend)
635-
end
636-
637626
@testset "Collinear coeftable" begin
638627
xs = [-1.0, 0.0, 1.0]
639628
ys = [0.0, 0.0, 0.0]

test/test_utils/test_utils.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,4 @@ using Test: @test, @testset, @test_throws
88
using Turing: Turing
99
using Turing: DynamicPPL
1010

11-
# Check that the ADTypeCheckContext works as expected.
12-
@testset "ADTypeCheckContext" begin
13-
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
14-
tm = test_model()
15-
adtypes = (
16-
Turing.AutoForwardDiff(),
17-
Turing.AutoReverseDiff(),
18-
# TODO: Mooncake
19-
# Turing.AutoMooncake(config=nothing),
20-
)
21-
for actual_adtype in adtypes
22-
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
23-
for expected_adtype in adtypes
24-
contextualised_tm = DynamicPPL.contextualize(
25-
tm, ADTypeCheckContext(expected_adtype, tm.context)
26-
)
27-
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
28-
if actual_adtype == expected_adtype
29-
# Check that this does not throw an error.
30-
Turing.sample(contextualised_tm, sampler, 2)
31-
else
32-
@test_throws AbstractWrongADBackendError Turing.sample(
33-
contextualised_tm, sampler, 2
34-
)
35-
end
36-
end
37-
end
38-
end
39-
end
40-
4111
end

0 commit comments

Comments
 (0)