Skip to content

Commit 803d2f5

Browse files
authored
Check that the correct AD backend is being used (#2291)
* Add ADTypeCheckContext
* Check ADType use in optimisation
* Use ADTypeCheckContext with hmc tests
* using A: A instead of import A
* More robust ADTypeCheckContext checks for Zygote
1 parent 4766fdd commit 803d2f5

File tree

4 files changed

+294
-1
lines changed

4 files changed

+294
-1
lines changed

test/mcmc/hmc.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module HMCTests
22

33
using ..Models: gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
#using ..Models: gdemo
56
using ..NumericalTests: check_gdemo, check_numerical
67
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
@@ -321,6 +322,15 @@ using Turing
321322
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
322323
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
323324
end
325+
326+
@testset "Check ADType" begin
    # Wrap the model's existing context in an ADTypeCheckContext for `adbackend`.
    check_context = ADTypeCheckContext(adbackend, gdemo_default.context)
    checked_model = DynamicPPL.contextualize(gdemo_default, check_context)
    hmc = HMC(0.1, 10; adtype=adbackend)
    # Sampling will error if the AD backend actually used is not the one set.
    sample(rng, checked_model, hmc, 10)
end
324334
end
325335

326336
end

test/optimisation/Optimisation.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
using Distributions
56
using Distributions.FillArrays: Zeros
67
using DynamicPPL: DynamicPPL
@@ -140,7 +141,6 @@ using Turing
140141
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
141142
)
142143
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
143-
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
144144
m4 = maximum_likelihood(
145145
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
146146
)
@@ -616,6 +616,18 @@ using Turing
616616
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
617617
@assert get(result, :c) == (; :c => Array{Float64}[])
618618
end
619+
620+
@testset "ADType" begin
    Random.seed!(222)
    backends = (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
    for adbackend in backends
        # Wrap the model's existing context in an ADTypeCheckContext for this backend.
        check_context = ADTypeCheckContext(adbackend, gdemo_default.context)
        checked_model = DynamicPPL.contextualize(gdemo_default, check_context)
        # Both calls will error if the AD backend being used is not the one set.
        maximum_likelihood(checked_model; adtype=adbackend)
        maximum_a_posteriori(checked_model; adtype=adbackend)
    end
end
619631
end
620632

621633
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Turing
77

88
include(pkgdir(Turing) * "/test/test_utils/models.jl")
99
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
10+
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")
1011

1112
Turing.setprogress!(false)
1213

test/test_utils/ad_utils.jl

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
module ADUtils
2+
3+
using ForwardDiff: ForwardDiff
4+
using ReverseDiff: ReverseDiff
5+
using Test: Test
6+
using Tracker: Tracker
7+
using Turing: Turing
8+
using Turing: DynamicPPL
9+
using Zygote: Zygote
10+
11+
export ADTypeCheckContext
12+
"""
    always_valid_eltypes

Element types that are always accepted for values in a `VarInfo`, regardless of which
ADType is expected.
"""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
15+
"""
    eltypes_by_adtype

Lookup table from each supported ADType to the tuple of element types that ADType uses.
"""
const eltypes_by_adtype = Dict(
    Turing.AutoForwardDiff => (ForwardDiff.Dual,),
    Turing.AutoReverseDiff => (
        ReverseDiff.TrackedArray,
        ReverseDiff.TrackedMatrix,
        ReverseDiff.TrackedReal,
        ReverseDiff.TrackedStyle,
        ReverseDiff.TrackedType,
        ReverseDiff.TrackedVecOrMat,
        ReverseDiff.TrackedVector,
    ),
    # Zygote.Dual is the very same type as ForwardDiff.Dual, so element types alone cannot
    # tell the two backends apart. check_adtype carries an additional Zygote-specific check.
    Turing.AutoZygote => (Zygote.Dual,),
    Turing.AutoTracker => (
        Tracker.Tracked,
        Tracker.TrackedArray,
        Tracker.TrackedMatrix,
        Tracker.TrackedReal,
        Tracker.TrackedStyle,
        Tracker.TrackedVecOrMat,
        Tracker.TrackedVector,
    ),
)
41+
"""
    AbstractWrongADBackendError

Abstract supertype of the errors thrown when the AD backend in use seems to differ from
the one that was expected.
"""
abstract type AbstractWrongADBackendError <: Exception end
48+
"""
    WrongADBackendError(actual_adtype, expected_adtype)

An error thrown when a different AD backend than the expected one seems to be in use.

# Fields
- `actual_adtype::Type`: the ADType that appears to be in use.
- `expected_adtype::Type`: the ADType that was expected.
"""
struct WrongADBackendError <: AbstractWrongADBackendError
    actual_adtype::Type
    expected_adtype::Type
end
58+
59+
# Print a one-line message naming the expected ADType and the one found in use.
function Base.showerror(io::IO, e::WrongADBackendError)
    msg = "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead."
    return print(io, msg)
end
64+
"""
    IncompatibleADTypeError(valtype, adtype)

An error thrown when a value is encountered whose type is not expected for the given
ADType.

# Fields
- `valtype::Type`: the type of the offending value.
- `adtype::Type`: the ADType the value was checked against.
"""
struct IncompatibleADTypeError <: AbstractWrongADBackendError
    valtype::Type
    adtype::Type
end
74+
75+
# Print a one-line message naming the unexpected element type and the ADType in question.
function Base.showerror(io::IO, e::IncompatibleADTypeError)
    msg = "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)"
    return print(io, msg)
end
81+
"""
    ADTypeCheckContext{ADType,ChildContext}
    ADTypeCheckContext(adbackend, child)

A context that verifies the expected ADType is the one actually being used.

Evaluating a model under this context checks that the types of the values held in a
`VarInfo` are compatible with the context's ADType; an `IncompatibleADTypeError` is thrown
when they are not. A `WrongADBackendError` is thrown when Zygote is detected although a
different ADType was expected.

For example, evaluating a model with
`ADTypeCheckContext(AutoForwardDiff(), child_context)`
throws if a type associated with e.g. ReverseDiff shows up inside the model.

`adbackend` may be either an ADType instance or an ADType type. The constructor throws an
`ArgumentError` when it is not a subtype of any key of `eltypes_by_adtype`.

Known limitation: if Zygote is expected but ForwardDiff is used instead, this context
cannot detect it, because both use the same element type.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
       DynamicPPL.AbstractContext
    child::ChildContext

    function ADTypeCheckContext(adbackend, child)
        # Accept both `AutoForwardDiff()` and `AutoForwardDiff`.
        adtype = adbackend isa Type ? adbackend : typeof(adbackend)
        is_supported = any(known -> adtype <: known, keys(eltypes_by_adtype))
        is_supported || throw(ArgumentError("Unsupported ADType: $adtype"))
        return new{adtype,typeof(child)}(child)
    end
end
110+
111+
# Return the ADType stored as the first type parameter of the context.
function adtype(::ADTypeCheckContext{T}) where {T}
    return T
end
112+
113+
# Declare ADTypeCheckContext as a parent context, i.e. one that wraps a child context.
function DynamicPPL.NodeTrait(::ADTypeCheckContext)
    return DynamicPPL.IsParent()
end
114+
# Return the wrapped child context.
function DynamicPPL.childcontext(c::ADTypeCheckContext)
    return c.child
end
115+
# Build a new ADTypeCheckContext with the same ADType as `c`, wrapping `child` instead.
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
    expected = adtype(c)
    return ADTypeCheckContext(expected, child)
end
118+
"""
    valid_eltypes(context::ADTypeCheckContext)

Return, as a tuple, every element type that is acceptable for the ADType of `context`: the
types registered for that ADType in `eltypes_by_adtype`, followed by
`always_valid_eltypes`.

Throws an `ArgumentError` if no key of `eltypes_by_adtype` is a supertype of the context's
ADType.
"""
function valid_eltypes(context::ADTypeCheckContext)
    expected = adtype(context)
    for known in keys(eltypes_by_adtype)
        expected <: known || continue
        return (eltypes_by_adtype[known]..., always_valid_eltypes...)
    end
    # The inner constructor already rejects unsupported ADTypes, so this should be
    # unreachable.
    throw(ArgumentError("Unsupported ADType: $(adtype(context))"))
end
134+
"""
    check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)

Check that the types of the values in `vi` are compatible with the ADType of `context`.

In addition, a hook is registered through `Zygote.hook` that throws unless
`adtype(context)` is `AutoZygote`. This is needed because Zygote uses the same element type
as ForwardDiff, so the two cannot be told apart by element type alone. The reverse case,
where Zygote is expected but ForwardDiff is actually used, is still not detected.

Throws an `IncompatibleADTypeError` if an incompatible element type is encountered, or a
`WrongADBackendError` if Zygote is used unexpectedly. Returns `nothing` otherwise.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
    # NOTE(review): presumably the hook body only runs when Zygote differentiates through
    # this call; confirm against the Zygote.hook documentation.
    Zygote.hook(vi) do _
        if !(adtype(context) <: Turing.AutoZygote)
            throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
        end
    end

    allowed = valid_eltypes(context)
    for element in vi[:]
        element_type = typeof(element)
        is_allowed = any(element_type .<: allowed)
        is_allowed || throw(IncompatibleADTypeError(element_type, adtype(context)))
    end
    return nothing
end
164+
165+
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
166+
# context, and then call check_adtype on the result before returning the results from the
167+
# child context.
168+
169+
# Forward to the child context, then run check_adtype on the VarInfo it returns.
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
    child = DynamicPPL.childcontext(context)
    val, lp, updated_vi = DynamicPPL.tilde_assume(child, right, vn, vi)
    check_adtype(context, updated_vi)
    return val, lp, updated_vi
end
176+
177+
# Forward to the child context (rng/sampler variant), then run check_adtype on the VarInfo
# it returns.
function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi)
    child = DynamicPPL.childcontext(context)
    val, lp, updated_vi = DynamicPPL.tilde_assume(rng, child, sampler, right, vn, vi)
    check_adtype(context, updated_vi)
    return val, lp, updated_vi
end
184+
185+
# Forward to the child context, then run check_adtype on the VarInfo it returns.
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi)
    child = DynamicPPL.childcontext(context)
    lp, updated_vi = DynamicPPL.tilde_observe(child, right, left, vi)
    check_adtype(context, updated_vi)
    return lp, updated_vi
end
190+
191+
# Forward to the child context (sampler variant), then run check_adtype on the VarInfo it
# returns.
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
    child = DynamicPPL.childcontext(context)
    lp, updated_vi = DynamicPPL.tilde_observe(child, sampler, right, left, vi)
    check_adtype(context, updated_vi)
    return lp, updated_vi
end
198+
199+
# Forward to the child context, then run check_adtype on the VarInfo it returns.
function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi)
    child = DynamicPPL.childcontext(context)
    val, lp, updated_vi = DynamicPPL.dot_tilde_assume(child, right, left, vn, vi)
    check_adtype(context, updated_vi)
    return val, lp, updated_vi
end
206+
207+
# Forward to the child context (rng/sampler variant), then run check_adtype on the VarInfo
# it returns.
function DynamicPPL.dot_tilde_assume(
    rng, context::ADTypeCheckContext, sampler, right, left, vn, vi
)
    child = DynamicPPL.childcontext(context)
    val, lp, updated_vi = DynamicPPL.dot_tilde_assume(
        rng, child, sampler, right, left, vn, vi
    )
    check_adtype(context, updated_vi)
    return val, lp, updated_vi
end
216+
217+
# Forward to the child context, then run check_adtype on the VarInfo it returns.
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi)
    child = DynamicPPL.childcontext(context)
    lp, updated_vi = DynamicPPL.dot_tilde_observe(child, right, left, vi)
    check_adtype(context, updated_vi)
    return lp, updated_vi
end
224+
225+
# Forward to the child context (sampler variant), then run check_adtype on the VarInfo it
# returns.
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
    child = DynamicPPL.childcontext(context)
    lp, updated_vi = DynamicPPL.dot_tilde_observe(child, sampler, right, left, vi)
    check_adtype(context, updated_vi)
    return lp, updated_vi
end
232+
233+
# Check that the ADTypeCheckContext works as expected.
234+
Test.@testset "ADTypeCheckContext" begin
    Turing.@model test_model() = x ~ Turing.Normal(0, 1)
    model = test_model()
    adtypes = (
        Turing.AutoForwardDiff(),
        Turing.AutoReverseDiff(),
        Turing.AutoZygote(),
        Turing.AutoTracker(),
    )
    # Try every (actual, expected) pairing: matching pairs must sample without error,
    # mismatching pairs must throw an AbstractWrongADBackendError.
    for actual_adtype in adtypes
        hmc = Turing.HMC(0.1, 5; adtype=actual_adtype)
        for expected_adtype in adtypes
            # TODO(mhauru) We are currently unable to check this case.
            uncheckable =
                actual_adtype == Turing.AutoForwardDiff() &&
                expected_adtype == Turing.AutoZygote()
            uncheckable && continue

            check_context = ADTypeCheckContext(expected_adtype, model.context)
            checked_model = DynamicPPL.contextualize(model, check_context)
            Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
                if actual_adtype != expected_adtype
                    Test.@test_throws AbstractWrongADBackendError Turing.sample(
                        checked_model, hmc, 2
                    )
                else
                    # Check that this does not throw an error.
                    Turing.sample(checked_model, hmc, 2)
                end
            end
        end
    end
end
269+
270+
end

0 commit comments

Comments
 (0)