@@ -181,9 +181,39 @@ ADTYPES = [
181
181
Turing. AutoMooncake (; config= nothing ),
182
182
]
183
183
184
+ # Check that ADTypeCheckContext itself works as expected.
185
+ @testset " ADTypeCheckContext" begin
186
+ @model test_model () = x ~ Normal (0 , 1 )
187
+ tm = test_model ()
188
+ adtypes = (
189
+ Turing. AutoForwardDiff (),
190
+ Turing. AutoReverseDiff (),
191
+ # TODO : Mooncake
192
+ # Turing.AutoMooncake(config=nothing),
193
+ )
194
+ for actual_adtype in adtypes
195
+ sampler = Turing. HMC (0.1 , 5 ; adtype= actual_adtype)
196
+ for expected_adtype in adtypes
197
+ contextualised_tm = DynamicPPL. contextualize (
198
+ tm, ADTypeCheckContext (expected_adtype, tm. context)
199
+ )
200
+ @testset " Expected: $expected_adtype , Actual: $actual_adtype " begin
201
+ if actual_adtype == expected_adtype
202
+ # Check that this does not throw an error.
203
+ Turing. sample (contextualised_tm, sampler, 2 )
204
+ else
205
+ @test_throws AbstractWrongADBackendError Turing. sample (
206
+ contextualised_tm, sampler, 2
207
+ )
208
+ end
209
+ end
210
+ end
211
+ end
212
+ end
213
+
184
214
@testset verbose = true " AD / ADTypeCheckContext" begin
185
- # This testset ensures that samplers don't accidentally override the AD
186
- # backend set in it.
215
+ # This testset ensures that samplers or optimisers don't accidentally
216
+ # override the AD backend set in it.
187
217
@testset " Check ADType" begin
188
218
seed = 123
189
219
alg = HMC (0.1 , 10 ; adtype= adtype)
@@ -192,6 +222,8 @@ ADTYPES = [
192
222
)
193
223
# These will error if the adbackend being used is not the one set.
194
224
sample (StableRNG (seed), m, alg, 10 )
225
+ maximum_likelihood (m; adtype= adbackend)
226
+ maximum_a_posteriori (m; adtype= adbackend)
195
227
end
196
228
end
197
229
0 commit comments