@@ -7,6 +7,8 @@ import ..ADUtils
7
7
using Bijectors: Bijectors
8
8
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
9
9
using DynamicPPL: DynamicPPL, Sampler
10
+ using DynamicPPL. TestUtils. AD: run_ad
11
+ using DynamicPPL. TestUtils: DEMO_MODELS
10
12
import ForwardDiff
11
13
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
12
14
import ReverseDiff
@@ -18,9 +20,41 @@ import Mooncake
18
20
using Test: @test , @test_logs , @testset , @test_throws
19
21
using Turing
20
22
21
- @testset " Testing hmc.jl with $adbackend " for adbackend in ADUtils. adbackends
22
- @info " Starting HMC tests with $adbackend "
23
+ @testset " AD / hmc.jl" begin
24
+ # AD tests need to be run with SamplingContext because samplers can potentially
25
+ # use this to define custom behaviour in the tilde-pipeline and thus change the
26
+ # code executed during model evaluation.
27
+ @testset " adtype=$adtype " for adtype in ADUtils. adbackends
28
+ @testset " alg=$alg " for alg in [
29
+ HMC (0.1 , 10 ; adtype= adtype),
30
+ HMCDA (0.8 , 0.75 ; adtype= adtype),
31
+ NUTS (1000 , 0.8 ; adtype= adtype),
32
+ ]
33
+ @info " Testing AD for $alg "
34
+
35
+ @testset " model=$(model. f) " for model in DEMO_MODELS
36
+ rng = StableRNG (123 )
37
+ ctx = DynamicPPL. SamplingContext (rng, DynamicPPL. Sampler (alg))
38
+ @test run_ad (model, adtype; context= ctx, test= true , benchmark= false ) isa Any
39
+ end
40
+ end
41
+
42
+ @testset " Check ADType" begin
43
+ seed = 123
44
+ alg = HMC (0.1 , 10 ; adtype= adbackend)
45
+ m = DynamicPPL. contextualize (
46
+ gdemo_default, ADTypeCheckContext (adbackend, gdemo_default. context)
47
+ )
48
+ # These will error if the adbackend being used is not the one set.
49
+ sample (StableRNG (seed), m, alg, 10 )
50
+ end
51
+ end
52
+ end
53
+
54
+ @testset " Testing hmc.jl" begin
55
+ @info " Starting HMC tests"
23
56
seed = 123
57
+ adbackend = Turing. DEFAULT_ADTYPE
24
58
25
59
@testset " constrained bounded" begin
26
60
obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
@@ -65,12 +99,6 @@ using Turing
65
99
check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ]; atol= 0.015 )
66
100
end
67
101
68
- @testset " hmc reverse diff" begin
69
- alg = HMC (0.1 , 10 ; adtype= adbackend)
70
- res = sample (StableRNG (seed), gdemo_default, alg, 4_000 )
71
- check_gdemo (res; rtol= 0.1 )
72
- end
73
-
74
102
# Test the sampling of a matrix-value distribution.
75
103
@testset " matrix support" begin
76
104
dist = Wishart (7 , [1 0.5 ; 0.5 1 ])
@@ -211,20 +239,20 @@ using Turing
211
239
end
212
240
213
241
@testset " prior" begin
242
+ # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243
+ # which means that it's _very_ difficult to find a good tolerance in the test below:)
244
+ prior_dist = truncated (Normal (3 , 1 ); lower= 0 )
245
+
214
246
@model function demo_hmc_prior ()
215
- # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
216
- # which means that it's _very_ difficult to find a good tolerance in the test below:)
217
- s ~ truncated (Normal (3 , 1 ); lower= 0 )
247
+ s ~ prior_dist
218
248
return m ~ Normal (0 , sqrt (s))
219
249
end
220
250
alg = NUTS (1000 , 0.8 ; adtype= adbackend)
221
251
gdemo_default_prior = DynamicPPL. contextualize (
222
252
demo_hmc_prior (), DynamicPPL. PriorContext ()
223
253
)
224
254
chain = sample (gdemo_default_prior, alg, 5_000 ; initial_params= [3.0 , 0.0 ])
225
- check_numerical (
226
- chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ]; atol= 0.2
227
- )
255
+ check_numerical (chain, [:s , :m ], [mean (prior_dist), 0 ]; atol= 0.2 )
228
256
end
229
257
230
258
@testset " warning for difficult init params" begin
@@ -292,8 +320,8 @@ using Turing
292
320
293
321
# Extract the `x` like this because running `generated_quantities` was how
294
322
# the issue was discovered, hence we also want to make sure that it works.
295
- results = generated_quantities (model, chain)
296
- results_prior = generated_quantities (model, chain_prior)
323
+ results = returned (model, chain)
324
+ results_prior = returned (model, chain_prior)
297
325
298
326
# Make sure none of the samples in the chains resulted in errors.
299
327
@test all (! isnothing, results)
@@ -315,15 +343,6 @@ using Turing
315
343
@test Turing. Inference. getstepsize (spl, hmc_state) isa Float64
316
344
end
317
345
end
318
-
319
- @testset " Check ADType" begin
320
- alg = HMC (0.1 , 10 ; adtype= adbackend)
321
- m = DynamicPPL. contextualize (
322
- gdemo_default, ADTypeCheckContext (adbackend, gdemo_default. context)
323
- )
324
- # These will error if the adbackend being used is not the one set.
325
- sample (StableRNG (seed), m, alg, 10 )
326
- end
327
346
end
328
347
329
348
end
0 commit comments