Skip to content

Commit 0b96004

Browse files
committed
Add MCMCSerial
1 parent a089476 commit 0b96004

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

src/AbstractMCMC.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,27 @@ abstract type AbstractMCMCParallel end
5858
"""
5959
MCMCThreads
6060
61-
The `MCMCThreads` algorithm allows to sample MCMC chains in parallel using multiple
61+
The `MCMCThreads` algorithm allows users to sample MCMC chains in parallel using multiple
6262
threads.
6363
"""
6464
struct MCMCThreads <: AbstractMCMCParallel end
6565

6666
"""
6767
MCMCDistributed
6868
69-
The `MCMCDistributed` algorithm allows to sample MCMC chains in parallel using multiple
69+
The `MCMCDistributed` algorithm allows users to sample MCMC chains in parallel using multiple
7070
processes.
7171
"""
7272
struct MCMCDistributed <: AbstractMCMCParallel end
7373

74+
75+
"""
76+
MCMCSerial
77+
78+
The `MCMCSerial` algorithm allows users to sample serially, with no thread or process parallelism.
79+
"""
80+
struct MCMCSerial <: AbstractMCMCParallel end
81+
7482
include("samplingstats.jl")
7583
include("logging.jl")
7684
include("interface.jl")

src/sample.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,5 +444,35 @@ function mcmcsample(
444444
return chainsstack(tighten_eltype(chains))
445445
end
446446

447+
function mcmcsample(
448+
rng::Random.AbstractRNG,
449+
model::AbstractModel,
450+
sampler::AbstractSampler,
451+
::MCMCSerial,
452+
N::Integer,
453+
nchains::Integer;
454+
progressname = "Sampling",
455+
kwargs...
456+
)
457+
# Check if the number of chains is larger than the number of samples
458+
if nchains > N
459+
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
460+
end
461+
462+
# Set up a chains vector.
463+
chains = Vector{Any}(undef, nchains)
464+
465+
# Sample each chain
466+
for i in 1:nchains
467+
# Sample a chain and save it to the vector.
468+
chains[i] = StatsBase.sample(rng, model, sampler, N;
469+
progressname = string(progressname, " (Chain $i of $nchains)"),
470+
kwargs...)
471+
end
472+
473+
# Concatenate the chains together.
474+
return chainsstack(tighten_eltype(chains))
475+
end
476+
447477
tighten_eltype(x) = x
448478
tighten_eltype(x::Vector{Any}) = map(identity, x)

test/sample.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,52 @@
225225
@test all(l.level > Logging.LogLevel(-1) for l in logs)
226226
end
227227

228+
@testset "Serial sampling" begin
229+
# No dedicated chains type
230+
N = 10_000
231+
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000)
232+
@test chains isa Vector{<:Vector{<:MySample}}
233+
@test length(chains) == 1000
234+
@test all(length(x) == N for x in chains)
235+
236+
Random.seed!(1234)
237+
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
238+
chain_type = MyChain)
239+
240+
# Test output type and size.
241+
@test chains isa Vector{<:MyChain}
242+
@test all(c.as[1] === missing for c in chains)
243+
@test length(chains) == 1000
244+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
245+
246+
# Test some statistical properties.
247+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
248+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
249+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
250+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
251+
252+
# Test reproducibility.
253+
Random.seed!(1234)
254+
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
255+
chain_type = MyChain)
256+
257+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
258+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
259+
260+
# Unexpected order of arguments.
261+
str = "Number of chains (10) is greater than number of samples per chain (5)"
262+
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
263+
MCMCSerial(), 5, 10;
264+
chain_type = MyChain)
265+
266+
# Suppress output.
267+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
268+
sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100;
269+
progress = false, chain_type = MyChain)
270+
end
271+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
272+
end
273+
228274
@testset "Chain constructors" begin
229275
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
230276
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)

0 commit comments

Comments
 (0)