@@ -22,60 +22,59 @@ using Turing
22
22
# Set a seed
23
23
rng = StableRNG (123 )
24
24
@testset " constrained bounded" begin
25
- obs = [0 ,1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
25
+ obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
26
26
27
27
@model function constrained_test (obs)
28
- p ~ Beta (2 ,2 )
29
- for i = 1 : length (obs)
28
+ p ~ Beta (2 , 2 )
29
+ for i in 1 : length (obs)
30
30
obs[i] ~ Bernoulli (p)
31
31
end
32
- p
32
+ return p
33
33
end
34
34
35
35
chain = sample (
36
36
rng,
37
37
constrained_test (obs),
38
38
HMC (1.5 , 3 ; adtype= adbackend),# using a large step size (1.5)
39
- 1000 )
39
+ 1000 ,
40
+ )
40
41
41
- check_numerical (chain, [:p ], [10 / 14 ], atol= 0.1 )
42
+ check_numerical (chain, [:p ], [10 / 14 ]; atol= 0.1 )
42
43
end
43
44
@testset " constrained simplex" begin
44
- obs12 = [1 ,2 , 1 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ]
45
+ obs12 = [1 , 2 , 1 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ]
45
46
46
47
@model function constrained_simplex_test (obs12)
47
48
ps ~ Dirichlet (2 , 3 )
48
49
pd ~ Dirichlet (4 , 1 )
49
- for i = 1 : length (obs12)
50
+ for i in 1 : length (obs12)
50
51
obs12[i] ~ Categorical (ps)
51
52
end
52
53
return ps
53
54
end
54
55
55
56
chain = sample (
56
- rng,
57
- constrained_simplex_test (obs12),
58
- HMC (0.75 , 2 ; adtype= adbackend),
59
- 1000 )
57
+ rng, constrained_simplex_test (obs12), HMC (0.75 , 2 ; adtype= adbackend), 1000
58
+ )
60
59
61
- check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ], atol= 0.015 )
60
+ check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ]; atol= 0.015 )
62
61
end
63
62
@testset " hmc reverse diff" begin
64
63
alg = HMC (0.1 , 10 ; adtype= adbackend)
65
64
res = sample (rng, gdemo_default, alg, 4000 )
66
- check_gdemo (res, rtol= 0.1 )
65
+ check_gdemo (res; rtol= 0.1 )
67
66
end
68
67
@testset " matrix support" begin
69
68
@model function hmcmatrixsup ()
70
- v ~ Wishart (7 , [1 0.5 ; 0.5 1 ])
69
+ return v ~ Wishart (7 , [1 0.5 ; 0.5 1 ])
71
70
end
72
71
73
72
model_f = hmcmatrixsup ()
74
73
n_samples = 1_000
75
74
vs = map (1 : 3 ) do _
76
75
chain = sample (rng, model_f, HMC (0.15 , 7 ; adtype= adbackend), n_samples)
77
76
r = reshape (Array (group (chain, :v )), n_samples, 2 , 2 )
78
- reshape (mean (r; dims = 1 ), 2 , 2 )
77
+ reshape (mean (r; dims= 1 ), 2 , 2 )
79
78
end
80
79
81
80
@test maximum (abs, mean (vs) - (7 * [1 0.5 ; 0.5 1 ])) <= 0.5
@@ -92,10 +91,10 @@ using Turing
92
91
M = N ÷ 4
93
92
x1s = rand (M) * 5
94
93
x2s = rand (M) * 5
95
- xt1s = Array ([[x1s[i]; x2s[i]] for i = 1 : M])
96
- append! (xt1s, Array ([[x1s[i] - 6 ; x2s[i] - 6 ] for i = 1 : M]))
97
- xt0s = Array ([[x1s[i]; x2s[i] - 6 ] for i = 1 : M])
98
- append! (xt0s, Array ([[x1s[i] - 6 ; x2s[i]] for i = 1 : M]))
94
+ xt1s = Array ([[x1s[i]; x2s[i]] for i in 1 : M])
95
+ append! (xt1s, Array ([[x1s[i] - 6 ; x2s[i] - 6 ] for i in 1 : M]))
96
+ xt0s = Array ([[x1s[i]; x2s[i] - 6 ] for i in 1 : M])
97
+ append! (xt0s, Array ([[x1s[i] - 6 ; x2s[i]] for i in 1 : M]))
99
98
100
99
xs = [xt1s; xt0s]
101
100
ts = [ones (M); ones (M); zeros (M); zeros (M)]
@@ -106,20 +105,22 @@ using Turing
106
105
var_prior = sqrt (1.0 / alpha) # variance of the Gaussian prior
107
106
108
107
@model function bnn (ts)
109
- b1 ~ MvNormal ([0. ;0. ; 0. ],
110
- [var_prior 0. 0. ; 0. var_prior 0. ; 0. 0. var_prior])
111
- w11 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
112
- w12 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
113
- w13 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
108
+ b1 ~ MvNormal (
109
+ [0.0 ; 0.0 ; 0.0 ], [var_prior 0.0 0.0 ; 0.0 var_prior 0.0 ; 0.0 0.0 var_prior]
110
+ )
111
+ w11 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
112
+ w12 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
113
+ w13 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
114
114
bo ~ Normal (0 , var_prior)
115
115
116
- wo ~ MvNormal ([0. ; 0 ; 0 ],
117
- [var_prior 0. 0. ; 0. var_prior 0. ; 0. 0. var_prior])
118
- for i = rand (1 : N, 10 )
116
+ wo ~ MvNormal (
117
+ [0.0 ; 0 ; 0 ], [var_prior 0.0 0.0 ; 0.0 var_prior 0.0 ; 0.0 0.0 var_prior]
118
+ )
119
+ for i in rand (1 : N, 10 )
119
120
y = nn (xs[i], b1, w11, w12, w13, bo, wo)
120
121
ts[i] ~ Bernoulli (y)
121
122
end
122
- b1, w11, w12, w13, bo, wo
123
+ return b1, w11, w12, w13, bo, wo
123
124
end
124
125
125
126
# Sampling
@@ -147,7 +148,7 @@ using Turing
147
148
Random. seed! (12345 ) # particle samplers do not support user-provided `rng` yet
148
149
alg3 = Gibbs (PG (20 , :s ), HMCDA (500 , 0.8 , 0.25 , :m ; init_ϵ= 0.05 , adtype= adbackend))
149
150
150
- res3 = sample (rng, gdemo_default, alg3, 3000 , discard_initial= 1000 )
151
+ res3 = sample (rng, gdemo_default, alg3, 3000 ; discard_initial= 1000 )
151
152
check_gdemo (res3)
152
153
end
153
154
@@ -191,8 +192,8 @@ using Turing
191
192
@testset " check discard" begin
192
193
alg = NUTS (100 , 0.8 ; adtype= adbackend)
193
194
194
- c1 = sample (rng, gdemo_default, alg, 500 , discard_adapt= true )
195
- c2 = sample (rng, gdemo_default, alg, 500 , discard_adapt= false )
195
+ c1 = sample (rng, gdemo_default, alg, 500 ; discard_adapt= true )
196
+ c2 = sample (rng, gdemo_default, alg, 500 ; discard_adapt= false )
196
197
197
198
@test size (c1, 1 ) == 500
198
199
@test size (c2, 1 ) == 500
@@ -210,20 +211,20 @@ using Turing
210
211
# https://github.com/TuringLang/DynamicPPL.jl/issues/27
211
212
@model function mwe1 (:: Type{T} = Float64) where {T<: Real }
212
213
m = Matrix {T} (undef, 2 , 3 )
213
- m .~ MvNormal (zeros (2 ), I)
214
+ return m .~ MvNormal (zeros (2 ), I)
214
215
end
215
216
@test sample (rng, mwe1 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
216
217
217
218
@model function mwe2 (:: Type{T} = Matrix{Float64}) where {T}
218
219
m = T (undef, 2 , 3 )
219
- m .~ MvNormal (zeros (2 ), I)
220
+ return m .~ MvNormal (zeros (2 ), I)
220
221
end
221
222
@test sample (rng, mwe2 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
222
223
223
224
# https://github.com/TuringLang/Turing.jl/issues/1308
224
225
@model function mwe3 (:: Type{T} = Array{Float64}) where {T}
225
226
m = T (undef, 2 , 3 )
226
- m .~ MvNormal (zeros (2 ), I)
227
+ return m .~ MvNormal (zeros (2 ), I)
227
228
end
228
229
@test sample (rng, mwe3 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
229
230
end
@@ -241,13 +242,17 @@ using Turing
241
242
@model function demo_hmc_prior ()
242
243
# NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243
244
# which means that it's _very_ difficult to find a good tolerance in the test below:)
244
- s ~ truncated (Normal (3 , 1 ), lower= 0 )
245
- m ~ Normal (0 , sqrt (s))
245
+ s ~ truncated (Normal (3 , 1 ); lower= 0 )
246
+ return m ~ Normal (0 , sqrt (s))
246
247
end
247
248
alg = NUTS (1000 , 0.8 ; adtype= adbackend)
248
- gdemo_default_prior = DynamicPPL. contextualize (demo_hmc_prior (), DynamicPPL. PriorContext ())
249
+ gdemo_default_prior = DynamicPPL. contextualize (
250
+ demo_hmc_prior (), DynamicPPL. PriorContext ()
251
+ )
249
252
chain = sample (gdemo_default_prior, alg, 10_000 ; initial_params= [3.0 , 0.0 ])
250
- check_numerical (chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ], atol= 0.2 )
253
+ check_numerical (
254
+ chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ]; atol= 0.2
255
+ )
251
256
end
252
257
253
258
@testset " warning for difficult init params" begin
@@ -262,7 +267,7 @@ using Turing
262
267
@test_logs (
263
268
:warn ,
264
269
" failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword" ,
265
- ) (:info ,) match_mode= :any begin
270
+ ) (:info ,) match_mode = :any begin
266
271
sample (demo_warn_initial_params (), NUTS (; adtype= adbackend), 5 )
267
272
end
268
273
end
@@ -271,7 +276,7 @@ using Turing
271
276
@model function vector_of_dirichlet (:: Type{TV} = Vector{Float64}) where {TV}
272
277
xs = Vector {TV} (undef, 2 )
273
278
xs[1 ] ~ Dirichlet (ones (5 ))
274
- xs[2 ] ~ Dirichlet (ones (5 ))
279
+ return xs[2 ] ~ Dirichlet (ones (5 ))
275
280
end
276
281
model = vector_of_dirichlet ()
277
282
chain = sample (model, NUTS (), 1000 )
@@ -296,15 +301,10 @@ using Turing
296
301
end
297
302
end
298
303
299
- model = buggy_model ();
300
- num_samples = 1_000 ;
304
+ model = buggy_model ()
305
+ num_samples = 1_000
301
306
302
- chain = sample (
303
- model,
304
- NUTS (),
305
- num_samples;
306
- initial_params= [0.5 , 1.75 , 1.0 ]
307
- )
307
+ chain = sample (model, NUTS (), num_samples; initial_params= [0.5 , 1.75 , 1.0 ])
308
308
chain_prior = sample (model, Prior (), num_samples)
309
309
310
310
# Extract the `x` like this because running `generated_quantities` was how
0 commit comments