Skip to content

Commit 269081e

Browse files
mhaurupenelopeysm
andauthored
Run JuliaFormatter on more files, remove trailing whitespace (#2374)
Co-authored-by: Penelope Yong <[email protected]>
1 parent 08331c5 commit 269081e

File tree

5 files changed

+58
-61
lines changed

5 files changed

+58
-61
lines changed

.JuliaFormatter.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,4 @@ ignore = [
88
# https://github.com/TuringLang/Turing.jl/pull/2328/files
99
"src/experimental/gibbs.jl",
1010
"test/experimental/gibbs.jl",
11-
# https://github.com/TuringLang/Turing.jl/pull/1887 # Enzyme PR
12-
"test/mcmc/hmc.jl",
13-
"test/mcmc/sghmc.jl",
1411
]

.github/workflows/DocsNav.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ jobs:
3232
3333
# Define the URL of the navbar to be used
3434
NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html"
35-
35+
3636
# Update all HTML files in the current directory (gh-pages root)
3737
./insert_navbar.sh . $NAVBAR_URL
38-
38+
3939
# Remove the insert_navbar.sh file
4040
rm insert_navbar.sh
41-
41+
4242
# Check if there are any changes
4343
if [[ -n $(git status -s) ]]; then
4444
git add .

src/mcmc/mh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Specifying a single distribution implies the use of static MH:
5454
5555
```julia
5656
# Use a static proposal for s² (which happens to be the same
57-
# as the prior) and a static proposal for m (note that this
57+
# as the prior) and a static proposal for m (note that this
5858
# isn't a random walk proposal).
5959
chain = sample(
6060
gdemo(1.5, 2.0),

test/mcmc/hmc.jl

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,60 +22,59 @@ using Turing
2222
# Set a seed
2323
rng = StableRNG(123)
2424
@testset "constrained bounded" begin
25-
obs = [0,1,0,1,1,1,1,1,1,1]
25+
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
2626

2727
@model function constrained_test(obs)
28-
p ~ Beta(2,2)
29-
for i = 1:length(obs)
28+
p ~ Beta(2, 2)
29+
for i in 1:length(obs)
3030
obs[i] ~ Bernoulli(p)
3131
end
32-
p
32+
return p
3333
end
3434

3535
chain = sample(
3636
rng,
3737
constrained_test(obs),
3838
HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5)
39-
1000)
39+
1000,
40+
)
4041

41-
check_numerical(chain, [:p], [10/14], atol=0.1)
42+
check_numerical(chain, [:p], [10 / 14]; atol=0.1)
4243
end
4344
@testset "constrained simplex" begin
44-
obs12 = [1,2,1,2,2,2,2,2,2,2]
45+
obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2]
4546

4647
@model function constrained_simplex_test(obs12)
4748
ps ~ Dirichlet(2, 3)
4849
pd ~ Dirichlet(4, 1)
49-
for i = 1:length(obs12)
50+
for i in 1:length(obs12)
5051
obs12[i] ~ Categorical(ps)
5152
end
5253
return ps
5354
end
5455

5556
chain = sample(
56-
rng,
57-
constrained_simplex_test(obs12),
58-
HMC(0.75, 2; adtype=adbackend),
59-
1000)
57+
rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000
58+
)
6059

61-
check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015)
60+
check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015)
6261
end
6362
@testset "hmc reverse diff" begin
6463
alg = HMC(0.1, 10; adtype=adbackend)
6564
res = sample(rng, gdemo_default, alg, 4000)
66-
check_gdemo(res, rtol=0.1)
65+
check_gdemo(res; rtol=0.1)
6766
end
6867
@testset "matrix support" begin
6968
@model function hmcmatrixsup()
70-
v ~ Wishart(7, [1 0.5; 0.5 1])
69+
return v ~ Wishart(7, [1 0.5; 0.5 1])
7170
end
7271

7372
model_f = hmcmatrixsup()
7473
n_samples = 1_000
7574
vs = map(1:3) do _
7675
chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples)
7776
r = reshape(Array(group(chain, :v)), n_samples, 2, 2)
78-
reshape(mean(r; dims = 1), 2, 2)
77+
reshape(mean(r; dims=1), 2, 2)
7978
end
8079

8180
@test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5
@@ -92,10 +91,10 @@ using Turing
9291
M = N ÷ 4
9392
x1s = rand(M) * 5
9493
x2s = rand(M) * 5
95-
xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M])
96-
append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M]))
97-
xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M])
98-
append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M]))
94+
xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M])
95+
append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M]))
96+
xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M])
97+
append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M]))
9998

10099
xs = [xt1s; xt0s]
101100
ts = [ones(M); ones(M); zeros(M); zeros(M)]
@@ -106,20 +105,22 @@ using Turing
106105
var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior
107106

108107
@model function bnn(ts)
109-
b1 ~ MvNormal([0. ;0.; 0.],
110-
[var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior])
111-
w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
112-
w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
113-
w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
108+
b1 ~ MvNormal(
109+
[0.0; 0.0; 0.0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior]
110+
)
111+
w11 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
112+
w12 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
113+
w13 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
114114
bo ~ Normal(0, var_prior)
115115

116-
wo ~ MvNormal([0.; 0; 0],
117-
[var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior])
118-
for i = rand(1:N, 10)
116+
wo ~ MvNormal(
117+
[0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior]
118+
)
119+
for i in rand(1:N, 10)
119120
y = nn(xs[i], b1, w11, w12, w13, bo, wo)
120121
ts[i] ~ Bernoulli(y)
121122
end
122-
b1, w11, w12, w13, bo, wo
123+
return b1, w11, w12, w13, bo, wo
123124
end
124125

125126
# Sampling
@@ -147,7 +148,7 @@ using Turing
147148
Random.seed!(12345) # particle samplers do not support user-provided `rng` yet
148149
alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))
149150

150-
res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000)
151+
res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000)
151152
check_gdemo(res3)
152153
end
153154

@@ -191,8 +192,8 @@ using Turing
191192
@testset "check discard" begin
192193
alg = NUTS(100, 0.8; adtype=adbackend)
193194

194-
c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true)
195-
c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false)
195+
c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true)
196+
c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false)
196197

197198
@test size(c1, 1) == 500
198199
@test size(c2, 1) == 500
@@ -210,20 +211,20 @@ using Turing
210211
# https://github.com/TuringLang/DynamicPPL.jl/issues/27
211212
@model function mwe1(::Type{T}=Float64) where {T<:Real}
212213
m = Matrix{T}(undef, 2, 3)
213-
m .~ MvNormal(zeros(2), I)
214+
return m .~ MvNormal(zeros(2), I)
214215
end
215216
@test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
216217

217218
@model function mwe2(::Type{T}=Matrix{Float64}) where {T}
218219
m = T(undef, 2, 3)
219-
m .~ MvNormal(zeros(2), I)
220+
return m .~ MvNormal(zeros(2), I)
220221
end
221222
@test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
222223

223224
# https://github.com/TuringLang/Turing.jl/issues/1308
224225
@model function mwe3(::Type{T}=Array{Float64}) where {T}
225226
m = T(undef, 2, 3)
226-
m .~ MvNormal(zeros(2), I)
227+
return m .~ MvNormal(zeros(2), I)
227228
end
228229
@test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
229230
end
@@ -241,13 +242,17 @@ using Turing
241242
@model function demo_hmc_prior()
242243
# NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243244
# which means that it's _very_ difficult to find a good tolerance in the test below:)
244-
s ~ truncated(Normal(3, 1), lower=0)
245-
m ~ Normal(0, sqrt(s))
245+
s ~ truncated(Normal(3, 1); lower=0)
246+
return m ~ Normal(0, sqrt(s))
246247
end
247248
alg = NUTS(1000, 0.8; adtype=adbackend)
248-
gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext())
249+
gdemo_default_prior = DynamicPPL.contextualize(
250+
demo_hmc_prior(), DynamicPPL.PriorContext()
251+
)
249252
chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0])
250-
check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2)
253+
check_numerical(
254+
chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2
255+
)
251256
end
252257

253258
@testset "warning for difficult init params" begin
@@ -262,7 +267,7 @@ using Turing
262267
@test_logs (
263268
:warn,
264269
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword",
265-
) (:info,) match_mode=:any begin
270+
) (:info,) match_mode = :any begin
266271
sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5)
267272
end
268273
end
@@ -271,7 +276,7 @@ using Turing
271276
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
272277
xs = Vector{TV}(undef, 2)
273278
xs[1] ~ Dirichlet(ones(5))
274-
xs[2] ~ Dirichlet(ones(5))
279+
return xs[2] ~ Dirichlet(ones(5))
275280
end
276281
model = vector_of_dirichlet()
277282
chain = sample(model, NUTS(), 1000)
@@ -296,15 +301,10 @@ using Turing
296301
end
297302
end
298303

299-
model = buggy_model();
300-
num_samples = 1_000;
304+
model = buggy_model()
305+
num_samples = 1_000
301306

302-
chain = sample(
303-
model,
304-
NUTS(),
305-
num_samples;
306-
initial_params=[0.5, 1.75, 1.0]
307-
)
307+
chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0])
308308
chain_prior = sample(model, Prior(), num_samples)
309309

310310
# Extract the `x` like this because running `generated_quantities` was how

test/mcmc/sghmc.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ using Turing
3434

3535
alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend)
3636
chain = sample(rng, gdemo_default, alg, 10_000)
37-
check_gdemo(chain, atol=0.1)
37+
check_gdemo(chain; atol=0.1)
3838
end
3939
end
4040

@@ -58,15 +58,15 @@ end
5858
@testset "sgld inference" begin
5959
rng = StableRNG(1)
6060

61-
chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000)
62-
check_gdemo(chain, atol = 0.2)
61+
chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000)
62+
check_gdemo(chain; atol=0.2)
6363

6464
# Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh)
6565
v = get(chain, [:SGLD_stepsize, :s, :m])
6666
s_weighted = dot(v.SGLD_stepsize, v.s) / sum(v.SGLD_stepsize)
6767
m_weighted = dot(v.SGLD_stepsize, v.m) / sum(v.SGLD_stepsize)
68-
@test s_weighted 49/24 atol=0.2
69-
@test m_weighted 7/6 atol=0.2
68+
@test s_weighted 49 / 24 atol = 0.2
69+
@test m_weighted 7 / 6 atol = 0.2
7070
end
7171
end
7272

0 commit comments

Comments
 (0)