@@ -4,9 +4,11 @@ using Turing
4
4
using DynamicPPL
5
5
using DynamicPPL. TestUtils: DEMO_MODELS
6
6
using DynamicPPL. TestUtils. AD: run_ad
7
+ using Random: Random
7
8
using StableRNGs: StableRNG
8
9
using Test
9
10
using .. Models: gdemo_default
11
+ import ForwardDiff, ReverseDiff, Mooncake
10
12
11
13
""" Element types that are always valid for a VarInfo regardless of ADType."""
12
14
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -181,17 +183,49 @@ ADTYPES = [
181
183
Turing. AutoMooncake (; config= nothing ),
182
184
]
183
185
186
+ # Check that ADTypeCheckContext itself works as expected.
187
+ @testset " ADTypeCheckContext" begin
188
+ @model test_model () = x ~ Normal (0 , 1 )
189
+ tm = test_model ()
190
+ adtypes = (
191
+ Turing. AutoForwardDiff (),
192
+ Turing. AutoReverseDiff (),
193
+ # TODO : Mooncake
194
+ # Turing.AutoMooncake(config=nothing),
195
+ )
196
+ for actual_adtype in adtypes
197
+ sampler = Turing. HMC (0.1 , 5 ; adtype= actual_adtype)
198
+ for expected_adtype in adtypes
199
+ contextualised_tm = DynamicPPL. contextualize (
200
+ tm, ADTypeCheckContext (expected_adtype, tm. context)
201
+ )
202
+ @testset " Expected: $expected_adtype , Actual: $actual_adtype " begin
203
+ if actual_adtype == expected_adtype
204
+ # Check that this does not throw an error.
205
+ Turing. sample (contextualised_tm, sampler, 2 )
206
+ else
207
+ @test_throws AbstractWrongADBackendError Turing. sample (
208
+ contextualised_tm, sampler, 2
209
+ )
210
+ end
211
+ end
212
+ end
213
+ end
214
+ end
215
+
184
216
@testset verbose = true " AD / ADTypeCheckContext" begin
185
- # This testset ensures that samplers don't accidentally override the AD
186
- # backend set in it.
187
- @testset " Check ADType " begin
217
+ # This testset ensures that samplers or optimisers don't accidentally
218
+ # override the AD backend set in it.
219
+ @testset " adtype= $adtype " for adtype in ADTYPES
188
220
seed = 123
189
221
alg = HMC (0.1 , 10 ; adtype= adtype)
190
222
m = DynamicPPL. contextualize (
191
223
gdemo_default, ADTypeCheckContext (adtype, gdemo_default. context)
192
224
)
193
225
# These will error if the adbackend being used is not the one set.
194
226
sample (StableRNG (seed), m, alg, 10 )
227
+ maximum_likelihood (m; adtype= adtype)
228
+ maximum_a_posteriori (m; adtype= adtype)
195
229
end
196
230
end
197
231
0 commit comments