@@ -2,7 +2,6 @@ module InferenceTests
2
2
3
3
using .. Models: gdemo_d, gdemo_default
4
4
using .. NumericalTests: check_gdemo, check_numerical
5
- import .. ADUtils
6
5
using Distributions: Bernoulli, Beta, InverseGamma, Normal
7
6
using Distributions: sample
8
7
import DynamicPPL
@@ -17,8 +16,9 @@ import Mooncake
17
16
using Test: @test , @test_throws , @testset
18
17
using Turing
19
18
20
- @testset " Testing inference.jl with $adbackend " for adbackend in ADUtils. adbackends
21
- @info " Starting Inference.jl tests with $adbackend "
19
+ @testset verbose = true " Testing Inference.jl" begin
20
+ @info " Starting Inference.jl tests"
21
+
22
22
seed = 23
23
23
24
24
@testset " threaded sampling" begin
@@ -27,12 +27,12 @@ using Turing
27
27
model = gdemo_default
28
28
29
29
samplers = (
30
- HMC (0.1 , 7 ; adtype = adbackend ),
30
+ HMC (0.1 , 7 ),
31
31
PG (10 ),
32
32
IS (),
33
33
MH (),
34
- Gibbs (:s => PG (3 ), :m => HMC (0.4 , 8 ; adtype = adbackend )),
35
- Gibbs (:s => HMC (0.1 , 5 ; adtype = adbackend ), :m => ESS ()),
34
+ Gibbs (:s => PG (3 ), :m => HMC (0.4 , 8 )),
35
+ Gibbs (:s => HMC (0.1 , 5 ), :m => ESS ()),
36
36
)
37
37
for sampler in samplers
38
38
Random. seed! (5 )
@@ -44,7 +44,7 @@ using Turing
44
44
@test chain1. value == chain2. value
45
45
end
46
46
47
- # Should also be stable with am explicit RNG
47
+ # Should also be stable with an explicit RNG
48
48
seed = 5
49
49
rng = Random. MersenneTwister (seed)
50
50
for sampler in samplers
@@ -61,27 +61,22 @@ using Turing
61
61
# Smoke test for default sample call.
62
62
@testset " gdemo_default" begin
63
63
chain = sample (
64
- StableRNG (seed),
65
- gdemo_default,
66
- HMC (0.1 , 7 ; adtype= adbackend),
67
- MCMCThreads (),
68
- 1_000 ,
69
- 4 ,
64
+ StableRNG (seed), gdemo_default, HMC (0.1 , 7 ), MCMCThreads (), 1_000 , 4
70
65
)
71
66
check_gdemo (chain)
72
67
73
68
# run sampler: progress logging should be disabled and
74
69
# it should return a Chains object
75
- sampler = Sampler (HMC (0.1 , 7 ; adtype = adbackend ))
70
+ sampler = Sampler (HMC (0.1 , 7 ))
76
71
chains = sample (StableRNG (seed), gdemo_default, sampler, MCMCThreads (), 10 , 4 )
77
72
@test chains isa MCMCChains. Chains
78
73
end
79
74
end
80
75
81
76
@testset " chain save/resume" begin
82
- alg1 = HMCDA (1000 , 0.65 , 0.15 ; adtype = adbackend )
77
+ alg1 = HMCDA (1000 , 0.65 , 0.15 )
83
78
alg2 = PG (20 )
84
- alg3 = Gibbs (:s => PG (30 ), :m => HMC (0.2 , 4 ; adtype = adbackend ))
79
+ alg3 = Gibbs (:s => PG (30 ), :m => HMC (0.2 , 4 ))
85
80
86
81
chn1 = sample (StableRNG (seed), gdemo_default, alg1, 10_000 ; save_state= true )
87
82
check_gdemo (chn1)
@@ -260,7 +255,7 @@ using Turing
260
255
261
256
smc = SMC ()
262
257
pg = PG (10 )
263
- gibbs = Gibbs (:p => HMC (0.2 , 3 ; adtype = adbackend ), :x => PG (10 ))
258
+ gibbs = Gibbs (:p => HMC (0.2 , 3 ), :x => PG (10 ))
264
259
265
260
chn_s = sample (StableRNG (seed), testbb (obs), smc, 200 )
266
261
chn_p = sample (StableRNG (seed), testbb (obs), pg, 200 )
@@ -273,22 +268,17 @@ using Turing
273
268
274
269
@testset " forbid global" begin
275
270
xs = [1.5 2.0 ]
276
- # xx = 1
277
271
278
272
@model function fggibbstest (xs)
279
273
s ~ InverseGamma (2 , 3 )
280
274
m ~ Normal (0 , sqrt (s))
281
- # xx ~ Normal(m, sqrt(s)) # this is illegal
282
-
283
275
for i in 1 : length (xs)
284
276
xs[i] ~ Normal (m, sqrt (s))
285
- # for xx in xs
286
- # xx ~ Normal(m, sqrt(s))
287
277
end
288
278
return s, m
289
279
end
290
280
291
- gibbs = Gibbs (:s => PG (10 ), :m => HMC (0.4 , 8 ; adtype = adbackend ))
281
+ gibbs = Gibbs (:s => PG (10 ), :m => HMC (0.4 , 8 ))
292
282
chain = sample (StableRNG (seed), fggibbstest (xs), gibbs, 2 )
293
283
end
294
284
@@ -353,7 +343,7 @@ using Turing
353
343
)
354
344
end
355
345
356
- # TODO (mhauru) What is this testing? Why does it not use the looped-over adbackend?
346
+ # TODO (mhauru) What is this testing? Why does it use a different adbackend?
357
347
@testset " new interface" begin
358
348
obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
359
349
@@ -382,9 +372,7 @@ using Turing
382
372
end
383
373
end
384
374
385
- chain = sample (
386
- StableRNG (seed), noreturn ([1.5 2.0 ]), HMC (0.1 , 10 ; adtype= adbackend), 4000
387
- )
375
+ chain = sample (StableRNG (seed), noreturn ([1.5 2.0 ]), HMC (0.1 , 10 ), 4000 )
388
376
check_numerical (chain, [:s , :m ], [49 / 24 , 7 / 6 ])
389
377
end
390
378
@@ -415,7 +403,7 @@ using Turing
415
403
end
416
404
417
405
@testset " sample" begin
418
- alg = Gibbs (:m => HMC (0.2 , 3 ; adtype = adbackend ), :s => PG (10 ))
406
+ alg = Gibbs (:m => HMC (0.2 , 3 ), :s => PG (10 ))
419
407
chn = sample (StableRNG (seed), gdemo_default, alg, 10 )
420
408
end
421
409
@@ -427,7 +415,7 @@ using Turing
427
415
return s, m
428
416
end
429
417
430
- alg = HMC (0.01 , 5 ; adtype = adbackend )
418
+ alg = HMC (0.01 , 5 )
431
419
x = randn (100 )
432
420
res = sample (StableRNG (seed), vdemo1 (x), alg, 10 )
433
421
@@ -442,7 +430,7 @@ using Turing
442
430
443
431
# Vector assumptions
444
432
N = 10
445
- alg = HMC (0.2 , 4 ; adtype = adbackend )
433
+ alg = HMC (0.2 , 4 )
446
434
447
435
@model function vdemo3 ()
448
436
x = Vector {Real} (undef, N)
@@ -497,7 +485,7 @@ using Turing
497
485
return s, m
498
486
end
499
487
500
- alg = HMC (0.01 , 5 ; adtype = adbackend )
488
+ alg = HMC (0.01 , 5 )
501
489
x = randn (100 )
502
490
res = sample (StableRNG (seed), vdemo1 (x), alg, 10 )
503
491
@@ -507,12 +495,12 @@ using Turing
507
495
end
508
496
509
497
D = 2
510
- alg = HMC (0.01 , 5 ; adtype = adbackend )
498
+ alg = HMC (0.01 , 5 )
511
499
res = sample (StableRNG (seed), vdemo2 (randn (D, 100 )), alg, 10 )
512
500
513
501
# Vector assumptions
514
502
N = 10
515
- alg = HMC (0.2 , 4 ; adtype = adbackend )
503
+ alg = HMC (0.2 , 4 )
516
504
517
505
@model function vdemo3 ()
518
506
x = Vector {Real} (undef, N)
@@ -559,7 +547,7 @@ using Turing
559
547
560
548
@testset " Type parameters" begin
561
549
N = 10
562
- alg = HMC (0.01 , 5 ; adtype = adbackend )
550
+ alg = HMC (0.01 , 5 )
563
551
x = randn (1000 )
564
552
@model function vdemo1 (:: Type{T} = Float64) where {T}
565
553
x = Vector {T} (undef, N)
0 commit comments