Skip to content

Commit 0d2955f

Browse files
committed
Finish AD tests
1 parent be08026 commit 0d2955f

File tree

4 files changed

+244
-261
lines changed

4 files changed

+244
-261
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
34
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
45
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"

test/ad.jl

Lines changed: 189 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,200 @@ using DynamicPPL.TestUtils.AD: run_ad
77
using StableRNGs: StableRNG
88
using Test
99
using ..Models: gdemo_default
10-
using ..ADUtils: ADTypeCheckContext, adbackends
10+
11+
"""Element types that are always valid for a VarInfo regardless of ADType."""
12+
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
13+
14+
"""A dictionary mapping ADTypes to the element types they use."""
15+
const eltypes_by_adtype = Dict(
16+
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
17+
Turing.AutoReverseDiff => (
18+
ReverseDiff.TrackedArray,
19+
ReverseDiff.TrackedMatrix,
20+
ReverseDiff.TrackedReal,
21+
ReverseDiff.TrackedStyle,
22+
ReverseDiff.TrackedType,
23+
ReverseDiff.TrackedVecOrMat,
24+
ReverseDiff.TrackedVector,
25+
),
26+
Turing.AutoMooncake => (Mooncake.CoDual,),
27+
)
28+
29+
"""
30+
AbstractWrongADBackendError
31+
32+
An abstract error thrown when we seem to be using a different AD backend than expected.
33+
"""
34+
abstract type AbstractWrongADBackendError <: Exception end
35+
36+
"""
37+
WrongADBackendError
38+
39+
An error thrown when we seem to be using a different AD backend than expected.
40+
"""
41+
struct WrongADBackendError <: AbstractWrongADBackendError
42+
actual_adtype::Type
43+
expected_adtype::Type
44+
end
45+
46+
function Base.showerror(io::IO, e::WrongADBackendError)
47+
return print(
48+
io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead."
49+
)
50+
end
51+
52+
"""
53+
IncompatibleADTypeError
54+
55+
An error thrown when an element type is encountered that is unexpected for the given ADType.
56+
"""
57+
struct IncompatibleADTypeError <: AbstractWrongADBackendError
58+
valtype::Type
59+
adtype::Type
60+
end
61+
62+
function Base.showerror(io::IO, e::IncompatibleADTypeError)
63+
return print(
64+
io,
65+
"Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)",
66+
)
67+
end
68+
69+
"""
70+
ADTypeCheckContext{ADType,ChildContext}
71+
72+
A context for checking that the expected ADType is being used.
73+
74+
Evaluating a model with this context will check that the types of values in a `VarInfo` are
75+
compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError`
76+
is thrown.
77+
78+
For instance, evaluating a model with
79+
`ADTypeCheckContext(AutoForwardDiff(), child_context)`
80+
would throw an error if within the model a type associated with e.g. ReverseDiff was
81+
encountered.
82+
83+
"""
84+
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
85+
DynamicPPL.AbstractContext
86+
child::ChildContext
87+
88+
function ADTypeCheckContext(adbackend, child)
89+
adtype = adbackend isa Type ? adbackend : typeof(adbackend)
90+
if !any(adtype <: k for k in keys(eltypes_by_adtype))
91+
throw(ArgumentError("Unsupported ADType: $adtype"))
92+
end
93+
return new{adtype,typeof(child)}(child)
94+
end
95+
end
96+
97+
adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType
98+
99+
DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
100+
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
101+
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
102+
return ADTypeCheckContext(adtype(c), child)
103+
end
104+
105+
"""
106+
valid_eltypes(context::ADTypeCheckContext)
107+
108+
Return the element types that are valid for the ADType of `context` as a tuple.
109+
"""
110+
function valid_eltypes(context::ADTypeCheckContext)
111+
context_at = adtype(context)
112+
for at in keys(eltypes_by_adtype)
113+
if context_at <: at
114+
return (eltypes_by_adtype[at]..., always_valid_eltypes...)
115+
end
116+
end
117+
# This should never be reached due to the check in the inner constructor.
118+
throw(ArgumentError("Unsupported ADType: $(adtype(context))"))
119+
end
120+
121+
"""
122+
check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)
123+
124+
Check that the element types in `vi` are compatible with the ADType of `context`.
125+
126+
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
127+
"""
128+
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
129+
valids = valid_eltypes(context)
130+
for val in vi[:]
131+
valtype = typeof(val)
132+
if !any(valtype .<: valids)
133+
throw(IncompatibleADTypeError(valtype, adtype(context)))
134+
end
135+
end
136+
return nothing
137+
end
138+
139+
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
140+
# context, and then call check_adtype on the result before returning the results from the
141+
# child context.
142+
143+
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
144+
value, logp, vi = DynamicPPL.tilde_assume(
145+
DynamicPPL.childcontext(context), right, vn, vi
146+
)
147+
check_adtype(context, vi)
148+
return value, logp, vi
149+
end
150+
151+
function DynamicPPL.tilde_assume(
152+
rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi
153+
)
154+
value, logp, vi = DynamicPPL.tilde_assume(
155+
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
156+
)
157+
check_adtype(context, vi)
158+
return value, logp, vi
159+
end
160+
161+
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi)
162+
logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi)
163+
check_adtype(context, vi)
164+
return logp, vi
165+
end
166+
167+
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
168+
logp, vi = DynamicPPL.tilde_observe(
169+
DynamicPPL.childcontext(context), sampler, right, left, vi
170+
)
171+
check_adtype(context, vi)
172+
return logp, vi
173+
end
174+
175+
"""
176+
All the ADTypes on which we want to run the tests.
177+
"""
178+
ADTYPES = [
179+
Turing.AutoForwardDiff(),
180+
Turing.AutoReverseDiff(; compile=false),
181+
Turing.AutoMooncake(; config=nothing),
182+
]
183+
184+
@testset verbose = true "AD / ADTypeCheckContext" begin
185+
# This testset ensures that samplers don't accidentally override the AD
186+
# backend set in it.
187+
@testset "Check ADType" begin
188+
seed = 123
189+
alg = HMC(0.1, 10; adtype=adtype)
190+
m = DynamicPPL.contextualize(
191+
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
192+
)
193+
# These will error if the adbackend being used is not the one set.
194+
sample(StableRNG(seed), m, alg, 10)
195+
end
196+
end
11197

12198
@testset verbose = true "AD / SamplingContext" begin
13199
# AD tests for gradient-based samplers need to be run with SamplingContext
14200
# because samplers can potentially use this to define custom behaviour in
15201
# the tilde-pipeline and thus change the code executed during model
16202
# evaluation.
17-
@testset "adtype=$adtype" for adtype in adbackends
203+
@testset "adtype=$adtype" for adtype in ADTYPES
18204
@testset "alg=$alg" for alg in [
19205
HMC(0.1, 10; adtype=adtype),
20206
HMCDA(0.8, 0.75; adtype=adtype),
@@ -30,16 +216,6 @@ using ..ADUtils: ADTypeCheckContext, adbackends
30216
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
31217
end
32218
end
33-
34-
@testset "Check ADType" begin
35-
seed = 123
36-
alg = HMC(0.1, 10; adtype=adtype)
37-
m = DynamicPPL.contextualize(
38-
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
39-
)
40-
# These will error if the adbackend being used is not the one set.
41-
sample(StableRNG(seed), m, alg, 10)
42-
end
43219
end
44220
end
45221

@@ -49,7 +225,7 @@ end
49225
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
50226
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
51227
# functions)
52-
@testset "adtype=$adtype" for adtype in adbackends
228+
@testset "adtype=$adtype" for adtype in ADTYPES
53229
@testset "model=$(model.f)" for model in DEMO_MODELS
54230
# All the demo models have variables `s` and `m`, so we'll pretend
55231
# that we're using a Gibbs sampler where both of them are sampled

test/runtests.jl

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ import Turing
99
# Fix the global Random.seed for reproducibility.
1010
seed!(23)
1111

12-
include(pkgdir(Turing) * "/test/test_utils/models.jl")
13-
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
14-
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")
12+
include("test_utils/models.jl")
13+
include("test_utils/numerical_tests.jl")
1514

1615
Turing.setprogress!(false)
1716
included_paths, excluded_paths = parse_args(ARGS)
@@ -30,63 +29,63 @@ macro timeit_include(path::AbstractString)
3029
end
3130

3231
@testset "Turing" verbose = true begin
33-
@testset "Test utils" begin
34-
@timeit_include("test_utils/test_utils.jl")
35-
end
36-
37-
@testset "Aqua" begin
38-
@timeit_include("Aqua.jl")
39-
end
32+
# @testset "Test utils" begin
33+
# @timeit_include("test_utils/test_utils.jl")
34+
# end
35+
#
36+
# @testset "Aqua" begin
37+
# @timeit_include("Aqua.jl")
38+
# end
4039

4140
@testset "AD" verbose = true begin
4241
@timeit_include("ad.jl")
4342
end
4443

45-
@testset "essential" verbose = true begin
46-
@timeit_include("essential/container.jl")
47-
end
48-
49-
@testset "samplers (without AD)" verbose = true begin
50-
@timeit_include("mcmc/particle_mcmc.jl")
51-
@timeit_include("mcmc/emcee.jl")
52-
@timeit_include("mcmc/ess.jl")
53-
@timeit_include("mcmc/is.jl")
54-
end
55-
56-
@timeit TIMEROUTPUT "inference" begin
57-
@testset "inference with samplers" verbose = true begin
58-
@timeit_include("mcmc/gibbs.jl")
59-
@timeit_include("mcmc/hmc.jl")
60-
@timeit_include("mcmc/Inference.jl")
61-
@timeit_include("mcmc/sghmc.jl")
62-
@timeit_include("mcmc/abstractmcmc.jl")
63-
@timeit_include("mcmc/mh.jl")
64-
@timeit_include("ext/dynamichmc.jl")
65-
@timeit_include("mcmc/repeat_sampler.jl")
66-
end
67-
68-
@testset "variational algorithms" begin
69-
@timeit_include("variational/advi.jl")
70-
end
71-
72-
@testset "mode estimation" verbose = true begin
73-
@timeit_include("optimisation/Optimisation.jl")
74-
@timeit_include("ext/OptimInterface.jl")
75-
end
76-
end
77-
78-
@testset "variational optimisers" begin
79-
@timeit_include("variational/optimisers.jl")
80-
end
81-
82-
@testset "stdlib" verbose = true begin
83-
@timeit_include("stdlib/distributions.jl")
84-
@timeit_include("stdlib/RandomMeasures.jl")
85-
end
86-
87-
@testset "utilities" begin
88-
@timeit_include("mcmc/utilities.jl")
89-
end
44+
# @testset "essential" verbose = true begin
45+
# @timeit_include("essential/container.jl")
46+
# end
47+
#
48+
# @testset "samplers (without AD)" verbose = true begin
49+
# @timeit_include("mcmc/particle_mcmc.jl")
50+
# @timeit_include("mcmc/emcee.jl")
51+
# @timeit_include("mcmc/ess.jl")
52+
# @timeit_include("mcmc/is.jl")
53+
# end
54+
#
55+
# @timeit TIMEROUTPUT "inference" begin
56+
# @testset "inference with samplers" verbose = true begin
57+
# @timeit_include("mcmc/gibbs.jl")
58+
# @timeit_include("mcmc/hmc.jl")
59+
# @timeit_include("mcmc/Inference.jl")
60+
# @timeit_include("mcmc/sghmc.jl")
61+
# @timeit_include("mcmc/abstractmcmc.jl")
62+
# @timeit_include("mcmc/mh.jl")
63+
# @timeit_include("ext/dynamichmc.jl")
64+
# @timeit_include("mcmc/repeat_sampler.jl")
65+
# end
66+
#
67+
# @testset "variational algorithms" begin
68+
# @timeit_include("variational/advi.jl")
69+
# end
70+
#
71+
# @testset "mode estimation" verbose = true begin
72+
# @timeit_include("optimisation/Optimisation.jl")
73+
# @timeit_include("ext/OptimInterface.jl")
74+
# end
75+
# end
76+
#
77+
# @testset "variational optimisers" begin
78+
# @timeit_include("variational/optimisers.jl")
79+
# end
80+
#
81+
# @testset "stdlib" verbose = true begin
82+
# @timeit_include("stdlib/distributions.jl")
83+
# @timeit_include("stdlib/RandomMeasures.jl")
84+
# end
85+
#
86+
# @testset "utilities" begin
87+
# @timeit_include("mcmc/utilities.jl")
88+
# end
9089
end
9190

9291
show(TIMEROUTPUT; compact=true, sortby=:firstexec)

0 commit comments

Comments
 (0)