Skip to content

Commit 7b6416f

Browse files
authored
Merge pull request #33 from TuringLang/distributed
Add multicore sampling
2 parents b003270 + bea35ec commit 7b6416f

File tree

3 files changed

+211
-34
lines changed

3 files changed

+211
-34
lines changed

src/AbstractMCMC.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@ import ConsoleProgressMonitor
44
import LoggingExtras
55
import ProgressLogging
66
import StatsBase
7-
using StatsBase: sample
87
import TerminalLoggers
98

109
import Distributed
1110
import Logging
1211
import Random
1312

13+
# Reexport sample
14+
using StatsBase: sample
15+
export sample
16+
17+
# Parallel sampling types
18+
export MCMCThreads, MCMCDistributed
19+
1420
"""
1521
AbstractChains
1622
@@ -39,6 +45,30 @@ An `AbstractModel` represents a generic model type that can be used to perform i
3945
"""
4046
abstract type AbstractModel end
4147

48+
"""
49+
AbstractMCMCParallel
50+
51+
An `AbstractMCMCParallel` algorithm represents a specific algorithm for sampling MCMC chains
52+
in parallel.
53+
"""
54+
abstract type AbstractMCMCParallel end
55+
56+
"""
57+
MCMCThreads
58+
59+
The `MCMCThreads` algorithm allows to sample MCMC chains in parallel using multiple
60+
threads.
61+
"""
62+
struct MCMCThreads <: AbstractMCMCParallel end
63+
64+
"""
65+
MCMCDistributed
66+
67+
The `MCMCDistributed` algorithm allows to sample MCMC chains in parallel using multiple
68+
processes.
69+
"""
70+
struct MCMCDistributed <: AbstractMCMCParallel end
71+
4272
include("logging.jl")
4373
include("interface.jl")
4474
include("sample.jl")

src/sample.jl

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Default implementations of `sample` and `psample`.
1+
# Default implementations of `sample`.
22

33
function StatsBase.sample(
44
model::AbstractModel,
@@ -19,25 +19,28 @@ function StatsBase.sample(
1919
return mcmcsample(rng, model, sampler, arg; kwargs...)
2020
end
2121

22-
function psample(
22+
function StatsBase.sample(
2323
model::AbstractModel,
2424
sampler::AbstractSampler,
25+
parallel::AbstractMCMCParallel,
2526
N::Integer,
2627
nchains::Integer;
2728
kwargs...
2829
)
29-
return psample(Random.GLOBAL_RNG, model, sampler, N, nchains; kwargs...)
30+
return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, parallel, N, nchains;
31+
kwargs...)
3032
end
3133

32-
function psample(
34+
function StatsBase.sample(
3335
rng::Random.AbstractRNG,
3436
model::AbstractModel,
3537
sampler::AbstractSampler,
38+
parallel::AbstractMCMCParallel,
3639
N::Integer,
3740
nchains::Integer;
3841
kwargs...
3942
)
40-
return mcmcpsample(rng, model, sampler, N, nchains; kwargs...)
43+
return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...)
4144
end
4245

4346
# Default implementations of regular and parallel sampling.
@@ -173,23 +176,27 @@ function mcmcsample(
173176
end
174177

175178
"""
176-
mcmcpsample([rng, ]model, sampler, N, nchains; kwargs...)
179+
mcmcsample([rng, ]model, sampler, parallel, N, nchains; kwargs...)
177180
178-
Sample `nchains` chains using the available threads, and combine them into a single chain.
179-
180-
By default, the random number generator, the model and the samplers are deep copied for each
181-
thread to prevent contamination between threads.
181+
Sample `nchains` chains in parallel using the `parallel` algorithm, and combine them into a
182+
single chain.
182183
"""
183-
function mcmcpsample(
184+
function mcmcsample(
184185
rng::Random.AbstractRNG,
185186
model::AbstractModel,
186187
sampler::AbstractSampler,
188+
::MCMCThreads,
187189
N::Integer,
188190
nchains::Integer;
189191
progress = true,
190-
progressname = "Parallel sampling",
192+
progressname = "Sampling ($(Threads.nthreads()) threads)",
191193
kwargs...
192194
)
195+
# Check if actually multiple threads are used.
196+
if Threads.nthreads() == 1
197+
@warn "Only a single thread available: MCMC chains are not sampled in parallel"
198+
end
199+
193200
# Copy the random number generator, model, and sample for each thread
194201
rngs = [deepcopy(rng) for _ in 1:Threads.nthreads()]
195202
models = [deepcopy(model) for _ in 1:Threads.nthreads()]
@@ -204,7 +211,7 @@ function mcmcpsample(
204211
@ifwithprogresslogger progress name=progressname begin
205212
# Create a channel for progress logging.
206213
if progress
207-
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
214+
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains))
208215
end
209216

210217
Distributed.@sync begin
@@ -245,3 +252,77 @@ function mcmcpsample(
245252
# Concatenate the chains together.
246253
return reduce(chainscat, chains)
247254
end
255+
256+
function mcmcsample(
257+
rng::Random.AbstractRNG,
258+
model::AbstractModel,
259+
sampler::AbstractSampler,
260+
::MCMCDistributed,
261+
N::Integer,
262+
nchains::Integer;
263+
progress = true,
264+
progressname = "Sampling ($(Distributed.nworkers()) processes)",
265+
kwargs...
266+
)
267+
# Check if actually multiple processes are used.
268+
if Distributed.nworkers() == 1
269+
@warn "Only a single process available: MCMC chains are not sampled in parallel"
270+
end
271+
272+
# Create a seed for each chain using the provided random number generator.
273+
seeds = rand(rng, UInt, nchains)
274+
275+
# Set up worker pool.
276+
pool = Distributed.CachingPool(Distributed.workers())
277+
278+
# Create a channel for progress logging.
279+
channel = progress ? Distributed.RemoteChannel(() -> Channel{Bool}(nchains)) : nothing
280+
281+
local chains
282+
@ifwithprogresslogger progress name=progressname begin
283+
Distributed.@sync begin
284+
# Update the progress bar.
285+
if progress
286+
Distributed.@async begin
287+
progresschains = 0
288+
while take!(channel)
289+
progresschains += 1
290+
ProgressLogging.@logprogress progresschains/nchains
291+
end
292+
end
293+
end
294+
295+
Distributed.@async begin
296+
chains = let rng=rng, model=model, sampler=sampler, N=N, channel=channel,
297+
kwargs=kwargs
298+
Distributed.pmap(pool, seeds) do seed
299+
# Seed a new random number generator with the pre-made seed.
300+
subrng = deepcopy(rng)
301+
Random.seed!(subrng, seed)
302+
303+
# Sample a chain.
304+
chain = StatsBase.sample(subrng, model, sampler, N;
305+
progress = false, kwargs...)
306+
307+
# Update the progress bar.
308+
channel === nothing || put!(channel, true)
309+
310+
# Return the new chain.
311+
return chain
312+
end
313+
end
314+
315+
# Stop updating the progress bar.
316+
progress && put!(channel, false)
317+
end
318+
end
319+
end
320+
321+
# Concatenate the chains together.
322+
return reduce(chainscat, chains)
323+
end
324+
325+
# Deprecations.
326+
Base.@deprecate psample(model, sampler, N, nchains; kwargs...) sample(model, sampler, MCMCThreads(), N, nchains; kwargs...) false
327+
Base.@deprecate psample(rng, model, sampler, N, nchains; kwargs...) sample(rng, model, sampler, MCMCThreads(), N, nchains; kwargs...) false
328+
Base.@deprecate mcmcpsample(rng, model, sampler, N, nchains; kwargs...) mcmcsample(rng, model, sampler, MCMCThreads(), N, nchains; kwargs...) false

test/runtests.jl

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using AbstractMCMC
2-
using AbstractMCMC: sample, psample, steps!
2+
using AbstractMCMC: steps!
33
using Atom.Progress: JunoProgressLogger
44
using ConsoleProgressMonitor: ProgressLogger
55
using IJulia
66
using LoggingExtras: TeeLogger, EarlyFilteredLogger
77
using TerminalLoggers: TerminalLogger
88

9+
using Distributed
910
import Logging
1011
using Random
1112
using Statistics
@@ -99,14 +100,26 @@ include("interface.jl")
99100
@test first(LOGGERS) === logger
100101
@test Logging.current_logger() === CURRENT_LOGGER
101102
end
103+
104+
@testset "Suppress output" begin
105+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
106+
sample(MyModel(), MySampler(), 100; progress = false, sleepy = true)
107+
end
108+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
109+
end
102110
end
103111

104112
if VERSION v"1.3"
105-
@testset "Parallel sampling" begin
106-
println("testing parallel sampling with ", Threads.nthreads(), " thread(s)...")
113+
@testset "Multithreaded sampling" begin
114+
if Threads.nthreads() == 1
115+
warnregex = r"^Only a single thread available"
116+
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
117+
10, 10; chain_type = MyChain)
118+
end
107119

108120
Random.seed!(1234)
109-
chains = psample(MyModel(), MySampler(), 10_000, 1000; chain_type = MyChain)
121+
chains = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
122+
chain_type = MyChain)
110123

111124
# test output type and size
112125
@test chains isa Vector{MyChain}
@@ -121,12 +134,69 @@ include("interface.jl")
121134

122135
# test reproducibility
123136
Random.seed!(1234)
124-
chains2 = psample(MyModel(), MySampler(), 10_000, 1000; chain_type = MyChain)
137+
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
138+
chain_type = MyChain)
125139

126140
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
141+
142+
# Suppress output.
143+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
144+
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
145+
progress = false, chain_type = MyChain)
146+
end
147+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
127148
end
128149
end
129150

151+
@testset "Multicore sampling" begin
152+
if nworkers() == 1
153+
warnregex = r"^Only a single process available"
154+
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(),
155+
10, 10; chain_type = MyChain)
156+
end
157+
158+
# Add worker processes.
159+
addprocs()
160+
161+
# Load all required packages (`interface.jl` needs Random).
162+
@everywhere begin
163+
using AbstractMCMC
164+
using AbstractMCMC: sample
165+
166+
using Random
167+
include("interface.jl")
168+
end
169+
170+
Random.seed!(1234)
171+
chains = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
172+
chain_type = MyChain)
173+
174+
# Test output type and size.
175+
@test chains isa Vector{MyChain}
176+
@test length(chains) == 1000
177+
@test all(x -> length(x.as) == length(x.bs) == 10_000, chains)
178+
179+
# Test some statistical properties.
180+
@test all(x -> isapprox(mean(x.as), 0.5; atol=1e-2), chains)
181+
@test all(x -> isapprox(var(x.as), 1 / 12; atol=5e-3), chains)
182+
@test all(x -> isapprox(mean(x.bs), 0; atol=5e-2), chains)
183+
@test all(x -> isapprox(var(x.bs), 1; atol=5e-2), chains)
184+
185+
# Test reproducibility.
186+
Random.seed!(1234)
187+
chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
188+
chain_type = MyChain)
189+
190+
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
191+
192+
# Suppress output.
193+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
194+
sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100;
195+
progress = false, chain_type = MyChain)
196+
end
197+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
198+
end
199+
130200
@testset "Chain constructors" begin
131201
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
132202
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
@@ -135,21 +205,6 @@ include("interface.jl")
135205
@test chain2 isa MyChain
136206
end
137207

138-
@testset "Suppress output" begin
139-
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
140-
sample(MyModel(), MySampler(), 100; progress = false, sleepy = true)
141-
end
142-
@test isempty(logs)
143-
144-
if VERSION v"1.3"
145-
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
146-
psample(MyModel(), MySampler(), 10_000, 1000;
147-
progress = false, chain_type = MyChain)
148-
end
149-
@test isempty(logs)
150-
end
151-
end
152-
153208
@testset "Iterator sampling" begin
154209
Random.seed!(1234)
155210
as = []
@@ -182,4 +237,15 @@ include("interface.jl")
182237
bmean = mean(x.b for x in chain)
183238
@test abs(bmean) <= 0.001 && length(chain) < 10_000
184239
end
240+
241+
@testset "Deprecations" begin
242+
@test_deprecated AbstractMCMC.psample(MyModel(), MySampler(), 10, 10;
243+
chain_type = MyChain)
244+
@test_deprecated AbstractMCMC.psample(Random.GLOBAL_RNG, MyModel(), MySampler(),
245+
10, 10;
246+
chain_type = MyChain)
247+
@test_deprecated AbstractMCMC.mcmcpsample(Random.GLOBAL_RNG, MyModel(),
248+
MySampler(), 10, 10;
249+
chain_type = MyChain)
250+
end
185251
end

0 commit comments

Comments
 (0)