Skip to content

Commit 1845c9a

Browse files
committed
tidy
1 parent 8b8bc8d commit 1845c9a

File tree

6 files changed

+29
-29
lines changed

6 files changed

+29
-29
lines changed

test/ad.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,9 @@ end
253253
conditioned_model = Turing.Inference.make_conditional(
254254
model, varnames, deepcopy(global_vi)
255255
)
256-
rng = StableRNG(123)
257-
@test run_ad(model, adtype; test=true, benchmark=false) isa Any
256+
@test run_ad(
257+
model, adtype; rng=StableRNG(123), test=true, benchmark=false
258+
) isa Any
258259
end
259260
end
260261
end

test/ext/dynamichmc.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ using StableRNGs: StableRNG
1111
using Turing
1212

1313
@testset "TuringDynamicHMCExt" begin
14-
rng = StableRNG(100)
1514
spl = externalsampler(DynamicHMC.NUTS())
16-
chn = sample(rng, gdemo_default, spl, 10_000)
15+
chn = sample(StableRNG(100), gdemo_default, spl, 10_000)
1716
check_gdemo(chn)
1817
end
1918

test/mcmc/callbacks.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using Test, Turing, AbstractMCMC, Random, Distributions, LinearAlgebra
88
end
99

1010
@testset "AbstractMCMC Callbacks Interface" begin
11-
rng = Random.default_rng()
1211
model = test_normals()
1312

1413
samplers = [
@@ -24,7 +23,10 @@ end
2423
for (name, sampler) in samplers
2524
@testset "$name" begin
2625
t1, s1 = AbstractMCMC.step(
27-
rng, model, sampler; initial_params=Turing.Inference.init_strategy(sampler)
26+
Random.default_rng(),
27+
model,
28+
sampler;
29+
initial_params=Turing.Inference.init_strategy(sampler),
2830
)
2931

3032
# ParamsWithStats returns named params (not θ[i])
@@ -46,6 +48,7 @@ end
4648
# NUTS second step has full AHMC transition metrics
4749
@testset "NUTS Transition Metrics" begin
4850
sampler = NUTS(10, 0.65)
51+
rng = Random.default_rng()
4952
t1, s1 = AbstractMCMC.step(
5053
rng, model, sampler; initial_params=Turing.Inference.init_strategy(sampler)
5154
)

test/mcmc/emcee.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@ using Turing
1111

1212
@testset "emcee.jl" begin
1313
@testset "gdemo" begin
14-
rng = StableRNG(9876)
15-
1614
n_samples = 1000
1715
n_walkers = 250
18-
1916
spl = Emcee(n_walkers, 2.0)
20-
chain = sample(rng, gdemo_default, spl, n_samples)
17+
chain = sample(StableRNG(9876), gdemo_default, spl, n_samples)
2118
check_gdemo(chain)
2219
end
2320

test/mcmc/gibbs.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,8 @@ end
578578
# Sampler to use for Gibbs components.
579579
hmc = HMC(0.1, 32)
580580
sampler = Gibbs(@varname(s) => hmc, @varname(m) => hmc)
581-
rng = StableRNG(42)
582581
chain = sample(
583-
rng,
582+
StableRNG(42),
584583
model,
585584
sampler,
586585
MCMCThreads(),
@@ -625,14 +624,13 @@ end
625624
end
626625

627626
@testset "multiple varnames" begin
628-
rng = Random.default_rng()
629-
630627
@testset "with both `s` and `m` as random" begin
631628
model = gdemo(1.5, 2.0)
632629
vns = (@varname(s), @varname(m))
633630
spl = Gibbs(vns => MH())
634631

635632
# `step`
633+
rng = Random.default_rng()
636634
transition, state = AbstractMCMC.step(rng, model, spl)
637635
check_transition_varnames(transition, vns)
638636
for _ in 1:5
@@ -641,8 +639,7 @@ end
641639
end
642640

643641
# `sample`
644-
rng = StableRNG(42)
645-
chain = sample(rng, model, spl, 1_000; progress=false)
642+
chain = sample(StableRNG(42), model, spl, 1_000; progress=false)
646643
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4)
647644
end
648645

@@ -652,6 +649,7 @@ end
652649
spl = Gibbs(vns => MH())
653650

654651
# `step`
652+
rng = Random.default_rng()
655653
transition, state = AbstractMCMC.step(rng, model, spl)
656654
check_transition_varnames(transition, vns)
657655
for _ in 1:5
@@ -695,7 +693,6 @@ end
695693
end
696694

697695
@testset "CSMC + ESS" begin
698-
rng = Random.default_rng()
699696
model = MoGtest_default
700697
spl = Gibbs(
701698
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15),
@@ -711,6 +708,7 @@ end
711708
@varname(mu2)
712709
)
713710
# `step`
711+
rng = Random.default_rng()
714712
transition, state = AbstractMCMC.step(rng, model, spl)
715713
check_transition_varnames(transition, vns)
716714
for _ in 1:5
@@ -719,17 +717,16 @@ end
719717
end
720718

721719
# Sample!
722-
rng = StableRNG(42)
723-
chain = sample(rng, MoGtest_default, spl, 1000; progress=false)
720+
chain = sample(StableRNG(42), MoGtest_default, spl, 1000; progress=false)
724721
check_MoGtest_default(chain; atol=0.2)
725722
end
726723

727724
@testset "CSMC + ESS (usage of implicit varname)" begin
728-
rng = Random.default_rng()
729725
model = MoGtest_default_z_vector
730726
spl = Gibbs(@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS())
731727
vns = (@varname(z), @varname(mu1), @varname(mu2))
732728
# `step`
729+
rng = Random.default_rng()
733730
transition, state = AbstractMCMC.step(rng, model, spl)
734731
check_transition_varnames(transition, vns)
735732
for _ in 1:5
@@ -738,8 +735,7 @@ end
738735
end
739736

740737
# Sample!
741-
rng = StableRNG(42)
742-
chain = sample(rng, model, spl, 1000; progress=false)
738+
chain = sample(StableRNG(42), model, spl, 1000; progress=false)
743739
check_MoGtest_default_z_vector(chain; atol=0.2)
744740
end
745741

@@ -775,9 +771,14 @@ end
775771
]
776772
@testset "$(sampler_inner)" for sampler_inner in samplers_inner
777773
sampler = Gibbs(@varname(m1) => sampler_inner, @varname(m2) => sampler_inner)
778-
rng = StableRNG(42)
779774
chain = sample(
780-
rng, model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
775+
StableRNG(42),
776+
model,
777+
sampler,
778+
1000;
779+
discard_initial=1000,
780+
thinning=10,
781+
n_adapts=0,
781782
)
782783
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
783784
check_logp_correct(sampler_inner)

test/mcmc/sghmc.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ using Turing
2121
end
2222

2323
@testset "sghmc inference" begin
24-
rng = StableRNG(123)
2524
alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5)
26-
chain = sample(rng, gdemo_default, alg, 10_000)
25+
chain = sample(StableRNG(123), gdemo_default, alg, 10_000)
2726
check_gdemo(chain; atol=0.1)
2827
end
2928

@@ -39,9 +38,9 @@ end
3938
end
4039

4140
@testset "sgld inference" begin
42-
rng = StableRNG(1)
43-
44-
chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000)
41+
chain = sample(
42+
StableRNG(1), gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000
43+
)
4544
check_gdemo(chain; atol=0.25)
4645

4746
# Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh)

0 commit comments

Comments
 (0)