Skip to content

Commit d46ba93

Browse files
authored
Merge pull request #97 from TuringLang/dw/reproducible
Ensure ensemble sampling is reproducible
2 parents 6dae58b + 22299e2 commit d46ba93

File tree

4 files changed

+107
-74
lines changed

4 files changed

+107
-74
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ jobs:
4646
version: ${{ matrix.version }}
4747
arch: ${{ matrix.arch }}
4848
- uses: julia-actions/cache@v1
49+
with:
50+
cache-packages: "false" # caching Conda.jl causes precompilation error
4951
- uses: julia-actions/julia-buildpkg@latest
5052
- uses: julia-actions/julia-runtest@latest
5153
env:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "3.3.0"
6+
version = "3.3.2"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ function mcmcsample(
305305
models = [deepcopy(model) for _ in interval]
306306
samplers = [deepcopy(sampler) for _ in interval]
307307

308-
# Create a seed for each chunk using the provided random number generator.
309-
seeds = rand(rng, UInt, nchunks)
308+
# Create a seed for each chain using the provided random number generator.
309+
seeds = rand(rng, UInt, nchains)
310310

311311
# Set up a chains vector.
312312
chains = Vector{Any}(undef, nchains)
@@ -339,25 +339,22 @@ function mcmcsample(
339339

340340
Distributed.@async begin
341341
try
342-
Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers)
343-
Threads.@spawn begin
342+
Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers)
343+
chainidxs = if i == nchunks
344+
((i - 1) * chunksize + 1):nchains
345+
else
346+
((i - 1) * chunksize + 1):(i * chunksize)
347+
end
348+
Threads.@spawn for chainidx in chainidxs
344349
# Seed the chunk-specific random number generator with the pre-made seed.
345-
Random.seed!(_rng, seed)
346-
347-
chainidxs = if i == nchunks
348-
((i - 1) * chunksize + 1):nchains
349-
else
350-
((i - 1) * chunksize + 1):(i * chunksize)
351-
end
352-
353-
for chainidx in chainidxs
354-
# Sample a chain and save it to the vector.
355-
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
356-
progress = false, kwargs...)
357-
358-
# Update the progress bar.
359-
progress && put!(channel, true)
360-
end
350+
Random.seed!(_rng, seeds[chainidx])
351+
352+
# Sample a chain and save it to the vector.
353+
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
354+
progress = false, kwargs...)
355+
356+
# Update the progress bar.
357+
progress && put!(channel, true)
361358
end
362359
end
363360
finally
@@ -469,12 +466,18 @@ function mcmcsample(
469466
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
470467
end
471468

469+
# Create a seed for each chain using the provided random number generator.
470+
seeds = rand(rng, UInt, nchains)
471+
472472
# Sample the chains.
473-
chains = map(
474-
i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
475-
kwargs...),
476-
1:nchains
477-
)
473+
chains = map(enumerate(seeds)) do (i, seed)
474+
Random.seed!(rng, seed)
475+
return StatsBase.sample(
476+
rng, model, sampler, N;
477+
progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
478+
kwargs...,
479+
)
480+
end
478481

479482
# Concatenate the chains together.
480483
return chainsstack(tighten_eltype(chains))

test/sample.jl

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -103,61 +103,59 @@
103103
end
104104
end
105105

106-
if VERSION v"1.3"
107-
@testset "Multithreaded sampling" begin
108-
if Threads.nthreads() == 1
109-
warnregex = r"^Only a single thread available"
110-
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
111-
10, 10)
112-
end
106+
@testset "Multithreaded sampling" begin
107+
if Threads.nthreads() == 1
108+
warnregex = r"^Only a single thread available"
109+
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
110+
10, 10)
111+
end
113112

114-
# No dedicated chains type
115-
N = 10_000
116-
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
117-
@test chains isa Vector{<:Vector{<:MySample}}
118-
@test length(chains) == 1000
119-
@test all(length(x) == N for x in chains)
113+
# No dedicated chains type
114+
N = 10_000
115+
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
116+
@test chains isa Vector{<:Vector{<:MySample}}
117+
@test length(chains) == 1000
118+
@test all(length(x) == N for x in chains)
120119

121-
Random.seed!(1234)
122-
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
123-
chain_type = MyChain)
120+
Random.seed!(1234)
121+
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
122+
chain_type = MyChain)
124123

125-
# test output type and size
126-
@test chains isa Vector{<:MyChain}
127-
@test length(chains) == 1000
128-
@test all(x -> length(x.as) == length(x.bs) == N, chains)
124+
# test output type and size
125+
@test chains isa Vector{<:MyChain}
126+
@test length(chains) == 1000
127+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
129128

130-
# test some statistical properties
131-
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
132-
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
133-
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
134-
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
129+
# test some statistical properties
130+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
131+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
132+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
133+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)
135134

136-
# test reproducibility
137-
Random.seed!(1234)
138-
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
139-
chain_type = MyChain)
135+
# test reproducibility
136+
Random.seed!(1234)
137+
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
138+
chain_type = MyChain)
140139

141-
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
142-
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
140+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
141+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
143142

144-
# Unexpected order of arguments.
145-
str = "Number of chains (10) is greater than number of samples per chain (5)"
146-
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
147-
MCMCThreads(), 5, 10;
148-
chain_type = MyChain)
143+
# Unexpected order of arguments.
144+
str = "Number of chains (10) is greater than number of samples per chain (5)"
145+
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
146+
MCMCThreads(), 5, 10;
147+
chain_type = MyChain)
149148

150-
# Suppress output.
151-
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
152-
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
153-
progress = false, chain_type = MyChain)
154-
end
155-
@test all(l.level > Logging.LogLevel(-1) for l in logs)
149+
# Suppress output.
150+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
151+
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
152+
progress = false, chain_type = MyChain)
153+
end
154+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
156155

157-
# Smoke test for nchains < nthreads
158-
if Threads.nthreads() == 2
159-
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
160-
end
156+
# Smoke test for nchains < nthreads
157+
if Threads.nthreads() == 2
158+
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
161159
end
162160
end
163161

@@ -201,7 +199,7 @@
201199
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
202200
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
203201
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
204-
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
202+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)
205203

206204
# Test reproducibility.
207205
Random.seed!(1234)
@@ -247,7 +245,7 @@
247245
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
248246
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
249247
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
250-
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
248+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)
251249

252250
# Test reproducibility.
253251
Random.seed!(1234)
@@ -271,6 +269,36 @@
271269
@test all(l.level > Logging.LogLevel(-1) for l in logs)
272270
end
273271

272+
@testset "Ensemble sampling: Reproducibility" begin
273+
N = 1_000
274+
nchains = 10
275+
276+
# Serial sampling
277+
Random.seed!(1234)
278+
chains_serial = sample(
279+
MyModel(), MySampler(), MCMCSerial(), N, nchains;
280+
progress=false, chain_type=MyChain
281+
)
282+
283+
# Multi-threaded sampling
284+
Random.seed!(1234)
285+
chains_threads = sample(
286+
MyModel(), MySampler(), MCMCThreads(), N, nchains;
287+
progress=false, chain_type=MyChain
288+
)
289+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)
290+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)
291+
292+
# Multi-core sampling
293+
Random.seed!(1234)
294+
chains_distributed = sample(
295+
MyModel(), MySampler(), MCMCDistributed(), N, nchains;
296+
progress=false, chain_type=MyChain
297+
)
298+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
299+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
300+
end
301+
274302
@testset "Chain constructors" begin
275303
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
276304
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)

0 commit comments

Comments
 (0)