Skip to content

Commit 686a341

Browse files
authored
Merge pull request #71 from TuringLang/csp/serial
Add MCMCSerial
2 parents a089476 + 48ba9b0 commit 686a341

File tree

6 files changed

+97
-14
lines changed

6 files changed

+97
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "3.1.0"
6+
version = "3.2.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ AbstractMCMC.sample(
2828
::AbstractRNG,
2929
::AbstractMCMC.AbstractModel,
3030
::AbstractMCMC.AbstractSampler,
31-
::AbstractMCMC.AbstractMCMCParallel,
31+
::AbstractMCMC.AbstractMCMCEnsemble,
3232
::Integer,
3333
::Integer,
3434
)
3535
```
3636

37-
Two algorithms are provided for parallel sampling with multiple threads and multiple processes,
38-
respectively:
37+
Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization):
3938
```@docs
4039
AbstractMCMC.MCMCThreads
4140
AbstractMCMC.MCMCDistributed
41+
AbstractMCMC.MCMCSerial
4242
```
4343

4444
## Common keyword arguments

src/AbstractMCMC.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using StatsBase: sample
1717
export sample
1818

1919
# Parallel sampling types
20-
export MCMCThreads, MCMCDistributed
20+
export MCMCThreads, MCMCDistributed, MCMCSerial
2121

2222
"""
2323
AbstractChains
@@ -48,34 +48,43 @@ An `AbstractModel` represents a generic model type that can be used to perform i
4848
abstract type AbstractModel end
4949

5050
"""
51-
AbstractMCMCParallel
51+
AbstractMCMCEnsemble
5252
53-
An `AbstractMCMCParallel` algorithm represents a specific algorithm for sampling MCMC chains
53+
An `AbstractMCMCEnsemble` algorithm represents a specific algorithm for sampling MCMC chains
5454
in parallel.
5555
"""
56-
abstract type AbstractMCMCParallel end
56+
abstract type AbstractMCMCEnsemble end
5757

5858
"""
5959
MCMCThreads
6060
61-
The `MCMCThreads` algorithm allows to sample MCMC chains in parallel using multiple
61+
The `MCMCThreads` algorithm allows users to sample MCMC chains in parallel using multiple
6262
threads.
6363
"""
64-
struct MCMCThreads <: AbstractMCMCParallel end
64+
struct MCMCThreads <: AbstractMCMCEnsemble end
6565

6666
"""
6767
MCMCDistributed
6868
69-
The `MCMCDistributed` algorithm allows to sample MCMC chains in parallel using multiple
69+
The `MCMCDistributed` algorithm allows users to sample MCMC chains in parallel using multiple
7070
processes.
7171
"""
72-
struct MCMCDistributed <: AbstractMCMCParallel end
72+
struct MCMCDistributed <: AbstractMCMCEnsemble end
73+
74+
75+
"""
76+
MCMCSerial
77+
78+
The `MCMCSerial` algorithm allows users to sample serially, with no thread or process parallelism.
79+
"""
80+
struct MCMCSerial <: AbstractMCMCEnsemble end
7381

7482
include("samplingstats.jl")
7583
include("logging.jl")
7684
include("interface.jl")
7785
include("sample.jl")
7886
include("stepper.jl")
7987
include("transducer.jl")
88+
include("deprecations.jl")
8089

8190
end # module AbstractMCMC

src/deprecations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble
2+
Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false

src/sample.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161
function StatsBase.sample(
6262
model::AbstractModel,
6363
sampler::AbstractSampler,
64-
parallel::AbstractMCMCParallel,
64+
parallel::AbstractMCMCEnsemble,
6565
N::Integer,
6666
nchains::Integer;
6767
kwargs...
@@ -80,7 +80,7 @@ function StatsBase.sample(
8080
rng::Random.AbstractRNG,
8181
model::AbstractModel,
8282
sampler::AbstractSampler,
83-
parallel::AbstractMCMCParallel,
83+
parallel::AbstractMCMCEnsemble,
8484
N::Integer,
8585
nchains::Integer;
8686
kwargs...
@@ -444,5 +444,31 @@ function mcmcsample(
444444
return chainsstack(tighten_eltype(chains))
445445
end
446446

447+
function mcmcsample(
448+
rng::Random.AbstractRNG,
449+
model::AbstractModel,
450+
sampler::AbstractSampler,
451+
::MCMCSerial,
452+
N::Integer,
453+
nchains::Integer;
454+
progressname = "Sampling",
455+
kwargs...
456+
)
457+
# Check if the number of chains is larger than the number of samples
458+
if nchains > N
459+
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
460+
end
461+
462+
# Sample the chains.
463+
chains = map(
464+
i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
465+
kwargs...),
466+
1:nchains
467+
)
468+
469+
# Concatenate the chains together.
470+
return chainsstack(tighten_eltype(chains))
471+
end
472+
447473
tighten_eltype(x) = x
448474
tighten_eltype(x::Vector{Any}) = map(identity, x)

test/sample.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,52 @@
225225
@test all(l.level > Logging.LogLevel(-1) for l in logs)
226226
end
227227

228+
@testset "Serial sampling" begin
229+
# No dedicated chains type
230+
N = 10_000
231+
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000)
232+
@test chains isa Vector{<:Vector{<:MySample}}
233+
@test length(chains) == 1000
234+
@test all(length(x) == N for x in chains)
235+
236+
Random.seed!(1234)
237+
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
238+
chain_type = MyChain)
239+
240+
# Test output type and size.
241+
@test chains isa Vector{<:MyChain}
242+
@test all(c.as[1] === missing for c in chains)
243+
@test length(chains) == 1000
244+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
245+
246+
# Test some statistical properties.
247+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
248+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
249+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
250+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
251+
252+
# Test reproducibility.
253+
Random.seed!(1234)
254+
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000;
255+
chain_type = MyChain)
256+
257+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
258+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
259+
260+
# Unexpected order of arguments.
261+
str = "Number of chains (10) is greater than number of samples per chain (5)"
262+
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
263+
MCMCSerial(), 5, 10;
264+
chain_type = MyChain)
265+
266+
# Suppress output.
267+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
268+
sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100;
269+
progress = false, chain_type = MyChain)
270+
end
271+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
272+
end
273+
228274
@testset "Chain constructors" begin
229275
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
230276
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)

0 commit comments

Comments
 (0)