@@ -17,70 +17,66 @@ using Test: @test, @test_throws, @testset
17
17
using Turing
18
18
19
19
@testset " Testing inference.jl with $adbackend " for adbackend in ADUtils. adbackends
20
- # Only test threading if 1.3+.
21
- if VERSION > v " 1.2"
22
- @testset " threaded sampling" begin
23
- # Test that chains with the same seed will sample identically.
24
- @testset " rng" begin
25
- model = gdemo_default
26
-
27
- # multithreaded sampling with PG causes segfaults on Julia 1.5.4
28
- # https://github.com/TuringLang/Turing.jl/issues/1571
29
- samplers = @static if VERSION <= v " 1.5.3" || VERSION >= v " 1.6.0"
30
- (
31
- HMC (0.1 , 7 ; adtype= adbackend),
32
- PG (10 ),
33
- IS (),
34
- MH (),
35
- Gibbs (PG (3 , :s ), HMC (0.4 , 8 , :m ; adtype= adbackend)),
36
- Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
37
- )
38
- else
39
- (
40
- HMC (0.1 , 7 ; adtype= adbackend),
41
- IS (),
42
- MH (),
43
- Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
44
- )
45
- end
46
- for sampler in samplers
47
- Random. seed! (5 )
48
- chain1 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
20
+ @testset " threaded sampling" begin
21
+ # Test that chains with the same seed will sample identically.
22
+ @testset " rng" begin
23
+ model = gdemo_default
24
+
25
+ # multithreaded sampling with PG causes segfaults on Julia 1.5.4
26
+ # https://github.com/TuringLang/Turing.jl/issues/1571
27
+ samplers = @static if VERSION <= v " 1.5.3" || VERSION >= v " 1.6.0"
28
+ (
29
+ HMC (0.1 , 7 ; adtype= adbackend),
30
+ PG (10 ),
31
+ IS (),
32
+ MH (),
33
+ Gibbs (PG (3 , :s ), HMC (0.4 , 8 , :m ; adtype= adbackend)),
34
+ Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
35
+ )
36
+ else
37
+ (
38
+ HMC (0.1 , 7 ; adtype= adbackend),
39
+ IS (),
40
+ MH (),
41
+ Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
42
+ )
43
+ end
44
+ for sampler in samplers
45
+ Random. seed! (5 )
46
+ chain1 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
49
47
50
- Random. seed! (5 )
51
- chain2 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
48
+ Random. seed! (5 )
49
+ chain2 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
52
50
53
- @test chain1. value == chain2. value
54
- end
51
+ @test chain1. value == chain2. value
52
+ end
55
53
56
- # Should also be stable with am explicit RNG
57
- seed = 5
58
- rng = Random. MersenneTwister (seed)
59
- for sampler in samplers
60
- Random. seed! (rng, seed)
61
- chain1 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
54
+ # Should also be stable with am explicit RNG
55
+ seed = 5
56
+ rng = Random. MersenneTwister (seed)
57
+ for sampler in samplers
58
+ Random. seed! (rng, seed)
59
+ chain1 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
62
60
63
- Random. seed! (rng, seed)
64
- chain2 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
61
+ Random. seed! (rng, seed)
62
+ chain2 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
65
63
66
- @test chain1. value == chain2. value
67
- end
64
+ @test chain1. value == chain2. value
68
65
end
66
+ end
69
67
70
- # Smoke test for default sample call.
71
- Random. seed! (100 )
72
- chain = sample (
73
- gdemo_default, HMC (0.1 , 7 ; adtype= adbackend), MCMCThreads (), 1000 , 4
74
- )
75
- check_gdemo (chain)
68
+ # Smoke test for default sample call.
69
+ Random. seed! (100 )
70
+ chain = sample (gdemo_default, HMC (0.1 , 7 ; adtype= adbackend), MCMCThreads (), 1000 , 4 )
71
+ check_gdemo (chain)
76
72
77
- # run sampler: progress logging should be disabled and
78
- # it should return a Chains object
79
- sampler = Sampler (HMC (0.1 , 7 ; adtype= adbackend), gdemo_default)
80
- chains = sample (gdemo_default, sampler, MCMCThreads (), 1000 , 4 )
81
- @test chains isa MCMCChains. Chains
82
- end
73
+ # run sampler: progress logging should be disabled and
74
+ # it should return a Chains object
75
+ sampler = Sampler (HMC (0.1 , 7 ; adtype= adbackend), gdemo_default)
76
+ chains = sample (gdemo_default, sampler, MCMCThreads (), 1000 , 4 )
77
+ @test chains isa MCMCChains. Chains
83
78
end
79
+
84
80
@testset " chain save/resume" begin
85
81
Random. seed! (1234 )
86
82
0 commit comments