@@ -7,14 +7,200 @@ using DynamicPPL.TestUtils.AD: run_ad
7
7
using StableRNGs: StableRNG
8
8
using Test
9
9
using .. Models: gdemo_default
10
- using .. ADUtils: ADTypeCheckContext, adbackends
10
+
11
+ """ Element types that are always valid for a VarInfo regardless of ADType."""
12
+ const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
13
+
14
+ """ A dictionary mapping ADTypes to the element types they use."""
15
+ const eltypes_by_adtype = Dict (
16
+ Turing. AutoForwardDiff => (ForwardDiff. Dual,),
17
+ Turing. AutoReverseDiff => (
18
+ ReverseDiff. TrackedArray,
19
+ ReverseDiff. TrackedMatrix,
20
+ ReverseDiff. TrackedReal,
21
+ ReverseDiff. TrackedStyle,
22
+ ReverseDiff. TrackedType,
23
+ ReverseDiff. TrackedVecOrMat,
24
+ ReverseDiff. TrackedVector,
25
+ ),
26
+ Turing. AutoMooncake => (Mooncake. CoDual,),
27
+ )
28
+
29
+ """
30
+ AbstractWrongADBackendError
31
+
32
+ An abstract error thrown when we seem to be using a different AD backend than expected.
33
+ """
34
+ abstract type AbstractWrongADBackendError <: Exception end
35
+
36
+ """
37
+ WrongADBackendError
38
+
39
+ An error thrown when we seem to be using a different AD backend than expected.
40
+ """
41
+ struct WrongADBackendError <: AbstractWrongADBackendError
42
+ actual_adtype:: Type
43
+ expected_adtype:: Type
44
+ end
45
+
46
+ function Base. showerror (io:: IO , e:: WrongADBackendError )
47
+ return print (
48
+ io, " Expected to use $(e. expected_adtype) , but using $(e. actual_adtype) instead."
49
+ )
50
+ end
51
+
52
+ """
53
+ IncompatibleADTypeError
54
+
55
+ An error thrown when an element type is encountered that is unexpected for the given ADType.
56
+ """
57
+ struct IncompatibleADTypeError <: AbstractWrongADBackendError
58
+ valtype:: Type
59
+ adtype:: Type
60
+ end
61
+
62
+ function Base. showerror (io:: IO , e:: IncompatibleADTypeError )
63
+ return print (
64
+ io,
65
+ " Incompatible ADType: Did not expect element of type $(e. valtype) with $(e. adtype) " ,
66
+ )
67
+ end
68
+
69
+ """
70
+ ADTypeCheckContext{ADType,ChildContext}
71
+
72
+ A context for checking that the expected ADType is being used.
73
+
74
+ Evaluating a model with this context will check that the types of values in a `VarInfo` are
75
+ compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError`
76
+ is thrown.
77
+
78
+ For instance, evaluating a model with
79
+ `ADTypeCheckContext(AutoForwardDiff(), child_context)`
80
+ would throw an error if within the model a type associated with e.g. ReverseDiff was
81
+ encountered.
82
+
83
+ """
84
+ struct ADTypeCheckContext{ADType,ChildContext<: DynamicPPL.AbstractContext } < :
85
+ DynamicPPL. AbstractContext
86
+ child:: ChildContext
87
+
88
+ function ADTypeCheckContext (adbackend, child)
89
+ adtype = adbackend isa Type ? adbackend : typeof (adbackend)
90
+ if ! any (adtype <: k for k in keys (eltypes_by_adtype))
91
+ throw (ArgumentError (" Unsupported ADType: $adtype " ))
92
+ end
93
+ return new {adtype,typeof(child)} (child)
94
+ end
95
+ end
96
+
97
+ adtype (_:: ADTypeCheckContext{ADType} ) where {ADType} = ADType
98
+
99
+ DynamicPPL. NodeTrait (:: ADTypeCheckContext ) = DynamicPPL. IsParent ()
100
+ DynamicPPL. childcontext (c:: ADTypeCheckContext ) = c. child
101
+ function DynamicPPL. setchildcontext (c:: ADTypeCheckContext , child)
102
+ return ADTypeCheckContext (adtype (c), child)
103
+ end
104
+
105
+ """
106
+ valid_eltypes(context::ADTypeCheckContext)
107
+
108
+ Return the element types that are valid for the ADType of `context` as a tuple.
109
+ """
110
+ function valid_eltypes (context:: ADTypeCheckContext )
111
+ context_at = adtype (context)
112
+ for at in keys (eltypes_by_adtype)
113
+ if context_at <: at
114
+ return (eltypes_by_adtype[at]. .. , always_valid_eltypes... )
115
+ end
116
+ end
117
+ # This should never be reached due to the check in the inner constructor.
118
+ throw (ArgumentError (" Unsupported ADType: $(adtype (context)) " ))
119
+ end
120
+
121
+ """
122
+ check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)
123
+
124
+ Check that the element types in `vi` are compatible with the ADType of `context`.
125
+
126
+ Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
127
+ """
128
+ function check_adtype (context:: ADTypeCheckContext , vi:: DynamicPPL.AbstractVarInfo )
129
+ valids = valid_eltypes (context)
130
+ for val in vi[:]
131
+ valtype = typeof (val)
132
+ if ! any (valtype .< : valids)
133
+ throw (IncompatibleADTypeError (valtype, adtype (context)))
134
+ end
135
+ end
136
+ return nothing
137
+ end
138
+
139
+ # A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
140
+ # context, and then call check_adtype on the result before returning the results from the
141
+ # child context.
142
+
143
+ function DynamicPPL. tilde_assume (context:: ADTypeCheckContext , right, vn, vi)
144
+ value, logp, vi = DynamicPPL. tilde_assume (
145
+ DynamicPPL. childcontext (context), right, vn, vi
146
+ )
147
+ check_adtype (context, vi)
148
+ return value, logp, vi
149
+ end
150
+
151
+ function DynamicPPL. tilde_assume (
152
+ rng:: Random.AbstractRNG , context:: ADTypeCheckContext , sampler, right, vn, vi
153
+ )
154
+ value, logp, vi = DynamicPPL. tilde_assume (
155
+ rng, DynamicPPL. childcontext (context), sampler, right, vn, vi
156
+ )
157
+ check_adtype (context, vi)
158
+ return value, logp, vi
159
+ end
160
+
161
+ function DynamicPPL. tilde_observe (context:: ADTypeCheckContext , right, left, vi)
162
+ logp, vi = DynamicPPL. tilde_observe (DynamicPPL. childcontext (context), right, left, vi)
163
+ check_adtype (context, vi)
164
+ return logp, vi
165
+ end
166
+
167
+ function DynamicPPL. tilde_observe (context:: ADTypeCheckContext , sampler, right, left, vi)
168
+ logp, vi = DynamicPPL. tilde_observe (
169
+ DynamicPPL. childcontext (context), sampler, right, left, vi
170
+ )
171
+ check_adtype (context, vi)
172
+ return logp, vi
173
+ end
174
+
175
+ """
176
+ All the ADTypes on which we want to run the tests.
177
+ """
178
+ ADTYPES = [
179
+ Turing. AutoForwardDiff (),
180
+ Turing. AutoReverseDiff (; compile= false ),
181
+ Turing. AutoMooncake (; config= nothing ),
182
+ ]
183
+
184
+ @testset verbose = true " AD / ADTypeCheckContext" begin
185
+ # This testset ensures that samplers don't accidentally override the AD
186
+ # backend set in it.
187
+ @testset " Check ADType" begin
188
+ seed = 123
189
+ alg = HMC (0.1 , 10 ; adtype= adtype)
190
+ m = DynamicPPL. contextualize (
191
+ gdemo_default, ADTypeCheckContext (adtype, gdemo_default. context)
192
+ )
193
+ # These will error if the adbackend being used is not the one set.
194
+ sample (StableRNG (seed), m, alg, 10 )
195
+ end
196
+ end
11
197
12
198
@testset verbose = true " AD / SamplingContext" begin
13
199
# AD tests for gradient-based samplers need to be run with SamplingContext
14
200
# because samplers can potentially use this to define custom behaviour in
15
201
# the tilde-pipeline and thus change the code executed during model
16
202
# evaluation.
17
- @testset " adtype=$adtype " for adtype in adbackends
203
+ @testset " adtype=$adtype " for adtype in ADTYPES
18
204
@testset " alg=$alg " for alg in [
19
205
HMC (0.1 , 10 ; adtype= adtype),
20
206
HMCDA (0.8 , 0.75 ; adtype= adtype),
@@ -30,16 +216,6 @@ using ..ADUtils: ADTypeCheckContext, adbackends
30
216
@test run_ad (model, adtype; context= ctx, test= true , benchmark= false ) isa Any
31
217
end
32
218
end
33
-
34
- @testset " Check ADType" begin
35
- seed = 123
36
- alg = HMC (0.1 , 10 ; adtype= adtype)
37
- m = DynamicPPL. contextualize (
38
- gdemo_default, ADTypeCheckContext (adtype, gdemo_default. context)
39
- )
40
- # These will error if the adbackend being used is not the one set.
41
- sample (StableRNG (seed), m, alg, 10 )
42
- end
43
219
end
44
220
end
45
221
49
225
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
50
226
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
51
227
# functions)
52
- @testset " adtype=$adtype " for adtype in adbackends
228
+ @testset " adtype=$adtype " for adtype in ADTYPES
53
229
@testset " model=$(model. f) " for model in DEMO_MODELS
54
230
# All the demo models have variables `s` and `m`, so we'll pretend
55
231
# that we're using a Gibbs sampler where both of them are sampled
0 commit comments