Skip to content

Commit 184d592

Browse files
Remove Random.seed! from test suite, use Xoshiro/StableRNG instead (#2779)
Closes #2726 Global state is bad, so I went through the test suite and replaced all `Random.seed!` calls with local alternatives. The replacement depends on what the test actually needs: - If the test only checks that results are roughly correct (using `atol=` or helpers like `check_gdemo`), I used `Xoshiro` passed directly to `sample`. - If the test checks exact reproducibility (using `==` or `isequal`), I used `StableRNG` instead. - If the seed was before optimization code like MLE/MAP (which is deterministic anyway), I just deleted it. - For `MCMCThreads` tests that reset the same rng between runs, I kept the `Random.seed!(rng, seed)` pattern since that's already using an explicit rng and not global state. I left the single `Random.seed!` in `runtests.jl` untouched since that one is intentional (as discussed in the issue). --------- Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent 7072ff1 commit 184d592

File tree

7 files changed

+36
-36
lines changed

7 files changed

+36
-36
lines changed

test/ext/dynamichmc.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ using Distributions: sample
77
using DynamicHMC: DynamicHMC
88
using DynamicPPL: DynamicPPL
99
using Random: Random
10+
using StableRNGs: StableRNG
1011
using Turing
1112

1213
@testset "TuringDynamicHMCExt" begin
13-
Random.seed!(100)
14+
rng = StableRNG(100)
1415
spl = externalsampler(DynamicHMC.NUTS())
15-
chn = sample(gdemo_default, spl, 10_000)
16+
chn = sample(rng, gdemo_default, spl, 10_000)
1617
check_gdemo(chn)
1718
end
1819

test/mcmc/Inference.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ForwardDiff
1010
using LinearAlgebra: I
1111
import MCMCChains
1212
import Random
13+
using Random: Xoshiro
1314
import ReverseDiff
1415
using StableRNGs: StableRNG
1516
using StatsFuns: logsumexp
@@ -35,11 +36,11 @@ using Turing
3536
Gibbs(:s => HMC(0.1, 5), :m => ESS()),
3637
)
3738
for sampler in samplers
38-
Random.seed!(5)
39-
chain1 = sample(model, sampler, MCMCThreads(), 10, 4)
39+
rng1 = Xoshiro(5)
40+
chain1 = sample(rng1, model, sampler, MCMCThreads(), 10, 4)
4041

41-
Random.seed!(5)
42-
chain2 = sample(model, sampler, MCMCThreads(), 10, 4)
42+
rng2 = Xoshiro(5)
43+
chain2 = sample(rng2, model, sampler, MCMCThreads(), 10, 4)
4344

4445
# For HMC, the first step does not have stats, so we need to use isequal to
4546
# avoid comparing `missing`s

test/mcmc/emcee.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@ using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
55
using Distributions: sample
66
using DynamicPPL: DynamicPPL
7-
using Random: Random
7+
using Random: Random, Xoshiro
8+
using StableRNGs: StableRNG
89
using Test: @test, @test_throws, @testset
910
using Turing
1011

1112
@testset "emcee.jl" begin
1213
@testset "gdemo" begin
13-
Random.seed!(9876)
14+
rng = StableRNG(9876)
1415

1516
n_samples = 1000
1617
n_walkers = 250
1718

1819
spl = Emcee(n_walkers, 2.0)
19-
chain = sample(gdemo_default, spl, n_samples)
20+
chain = sample(rng, gdemo_default, spl, n_samples)
2021
check_gdemo(chain)
2122
end
2223

@@ -25,18 +26,18 @@ using Turing
2526
@info "Testing emcee with large number of iterations"
2627
spl = Emcee(10, 2.0)
2728
n_samples = 10_000
28-
chain = sample(gdemo_default, spl, n_samples)
29+
chain = sample(StableRNG(5), gdemo_default, spl, n_samples)
2930
check_gdemo(chain)
3031
end
3132

3233
@testset "initial parameters" begin
3334
nwalkers = 250
3435
spl = Emcee(nwalkers, 2.0)
3536

36-
Random.seed!(1234)
37-
chain1 = sample(gdemo_default, spl, 1)
38-
Random.seed!(1234)
39-
chain2 = sample(gdemo_default, spl, 1)
37+
rng1 = Xoshiro(1234)
38+
chain1 = sample(rng1, gdemo_default, spl, 1)
39+
rng2 = Xoshiro(1234)
40+
chain2 = sample(rng2, gdemo_default, spl, 1)
4041
@test Array(chain1) == Array(chain2)
4142

4243
initial_nt = DynamicPPL.InitFromParams((s=2.0, m=1.0))

test/mcmc/gibbs.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using Distributions: InverseGamma, Normal
1313
using Distributions: sample
1414
using DynamicPPL: DynamicPPL
1515
using ForwardDiff: ForwardDiff
16-
using Random: Random
16+
using Random: Random, Xoshiro
1717
using ReverseDiff: ReverseDiff
1818
using StableRNGs: StableRNG
1919
using Test: @inferred, @test, @test_broken, @test_throws, @testset
@@ -262,10 +262,8 @@ end
262262
sampler2 = Gibbs(
263263
@varname(s) => MH(), @varname(s) => MH(), @varname(s) => MH(), @varname(m) => ESS()
264264
)
265-
Random.seed!(23)
266-
chain1 = sample(gdemo_default, sampler1, 10)
267-
Random.seed!(23)
268-
chain2 = sample(gdemo_default, sampler1, 10)
265+
chain1 = sample(Xoshiro(23), gdemo_default, sampler1, 10)
266+
chain2 = sample(Xoshiro(23), gdemo_default, sampler1, 10)
269267
@test chain1.value == chain2.value
270268
end
271269

@@ -681,8 +679,9 @@ end
681679
# Sampler to use for Gibbs components.
682680
hmc = HMC(0.1, 32)
683681
sampler = Gibbs(@varname(s) => hmc, @varname(m) => hmc)
684-
Random.seed!(42)
682+
rng = StableRNG(42)
685683
chain = sample(
684+
rng,
686685
model,
687686
sampler,
688687
MCMCThreads(),
@@ -696,8 +695,9 @@ end
696695

697696
# "Ground truth" samples.
698697
# TODO: Replace with closed-form sampling once that is implemented in DynamicPPL.
699-
Random.seed!(42)
698+
700699
chain_true = sample(
700+
StableRNG(42),
701701
model,
702702
NUTS(),
703703
MCMCThreads(),
@@ -742,8 +742,8 @@ end
742742
end
743743

744744
# `sample`
745-
Random.seed!(42)
746-
chain = sample(model, spl, 1_000; progress=false)
745+
rng = StableRNG(42)
746+
chain = sample(rng, model, spl, 1_000; progress=false)
747747
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4)
748748
end
749749

@@ -820,8 +820,8 @@ end
820820
end
821821

822822
# Sample!
823-
Random.seed!(42)
824-
chain = sample(MoGtest_default, spl, 1000; progress=false)
823+
rng = StableRNG(42)
824+
chain = sample(rng, MoGtest_default, spl, 1000; progress=false)
825825
check_MoGtest_default(chain; atol=0.2)
826826
end
827827

@@ -846,8 +846,8 @@ end
846846
end
847847

848848
# Sample!
849-
Random.seed!(42)
850-
chain = sample(model, spl, 1000; progress=false)
849+
rng = StableRNG(42)
850+
chain = sample(rng, model, spl, 1000; progress=false)
851851
check_MoGtest_default_z_vector(chain; atol=0.2)
852852
end
853853

@@ -883,9 +883,9 @@ end
883883
]
884884
@testset "$(sampler_inner)" for sampler_inner in samplers_inner
885885
sampler = Gibbs(@varname(m1) => sampler_inner, @varname(m2) => sampler_inner)
886-
Random.seed!(42)
886+
rng = StableRNG(42)
887887
chain = sample(
888-
model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
888+
rng, model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
889889
)
890890
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
891891
check_logp_correct(sampler_inner)

test/mcmc/hmc.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ using Turing
131131
# easily make it fail, despite many more samples than taken by most other tests. Hence
132132
# explicitly specifying the seeds here.
133133
@testset "hmcda+gibbs inference" begin
134-
Random.seed!(12345)
135134
alg = Gibbs(:s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05))
136135
res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000)
137136
check_gdemo(res)

test/mcmc/particle_mcmc.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ using Turing
5555
end
5656

5757
@testset "logevidence" begin
58-
Random.seed!(100)
59-
6058
@model function test()
6159
a ~ Normal(0, 1)
6260
x ~ Bernoulli(1)
@@ -67,7 +65,7 @@ using Turing
6765
return x
6866
end
6967

70-
chains_smc = sample(test(), SMC(), 100)
68+
chains_smc = sample(StableRNG(100), test(), SMC(), 100)
7169

7270
@test all(isone, chains_smc[:x])
7371
# For SMC, the chain stores the collective logevidence of the sampled trajectories

test/stdlib/RandomMeasures.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module RandomMeasuresTests
22

33
using Distributions: Normal, sample
44
using Random: Random
5+
using StableRNGs: StableRNG
56
using Test: @test, @testset
67
using Turing
78
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
@@ -49,13 +50,12 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
4950
end
5051

5152
# Generate some test data.
52-
Random.seed!(1)
53-
data = vcat(randn(10), randn(10) .- 5, randn(10) .+ 10)
53+
rng = StableRNG(1)
54+
data = vcat(randn(rng, 10), randn(rng, 10) .- 5, randn(rng, 10) .+ 10)
5455
data .-= mean(data)
5556
data /= std(data)
5657

5758
# MCMC sampling
58-
Random.seed!(2)
5959
iterations = 500
6060
model_fun = infiniteGMM(data)
6161
chain = sample(model_fun, SMC(), iterations)

0 commit comments

Comments
 (0)