Skip to content

Commit f5a6535

Browse files
committed
Parallel sampling with ProgressLogging
1 parent b5ea802 commit f5a6535

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1919
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2020
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
21+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2122

2223
[compat]
2324
BangBang = "0.3.19, 0.4"
@@ -29,6 +30,7 @@ ProgressLogging = "0.1"
2930
StatsBase = "0.32, 0.33, 0.34"
3031
TerminalLoggers = "0.1"
3132
Transducers = "0.4.30"
33+
UUIDs = "1.11.0"
3234
julia = "1.6"
3335

3436
[extras]

src/sample.jl

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using UUIDs: uuid4
2+
13
# Default implementations of `sample`.
24
const PROGRESS = Ref(true)
35

@@ -144,11 +146,22 @@ function mcmcsample(
144146
@ifwithprogresslogger progress name = progressname begin
145147
# Determine threshold values for progress logging
146148
# (one update per 0.5% of progress)
147-
if (progress == true || progress === nothing)
149+
if !(progress == false)
148150
threshold = Ntotal ÷ 200
149151
next_update = threshold
150152
end
151153

154+
# Ugly hacky code to reset the start timer if called from a multi-chain
155+
# sampling process
156+
if progress isa ProgressLogging.Progress
157+
try
158+
bartrees = Logging.current_logger().loggers[1].logger.bartrees
159+
bar = TerminalLoggers.findbar(bartrees, progress.id).data
160+
bar.tfirst = time()
161+
catch
162+
end
163+
end
164+
152165
# Obtain the initial sample and state.
153166
sample, state = if num_warmup > 0
154167
if initial_state === nothing
@@ -170,7 +183,8 @@ function mcmcsample(
170183
if progress == true
171184
ProgressLogging.@logprogress itotal / Ntotal
172185
else
173-
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
186+
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
187+
progress.id
174188
end
175189
next_update = itotal + threshold
176190
end
@@ -189,7 +203,8 @@ function mcmcsample(
189203
if progress == true
190204
ProgressLogging.@logprogress itotal / Ntotal
191205
else
192-
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
206+
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
207+
progress.id
193208
end
194209
next_update = itotal + threshold
195210
end
@@ -218,7 +233,8 @@ function mcmcsample(
218233
if progress == true
219234
ProgressLogging.@logprogress itotal / Ntotal
220235
else
221-
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
236+
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
237+
progress.id
222238
end
223239
next_update = itotal + threshold
224240
end
@@ -243,7 +259,8 @@ function mcmcsample(
243259
if progress == true
244260
ProgressLogging.@logprogress itotal / Ntotal
245261
else
246-
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
262+
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
263+
progress.id
247264
end
248265
next_update = itotal + threshold
249266
end
@@ -432,6 +449,18 @@ function mcmcsample(
432449
if progress
433450
channel = Channel{Bool}(length(interval))
434451
end
452+
# Generate nchains independent UUIDs for each progress bar
453+
uuids = [uuid4() for _ in 1:nchains]
454+
# Start the progress bars (but in reverse order, because
455+
# ProgressLogging prints from the bottom up, and we want chain 1 to
456+
# show up at the top)
457+
# TODO: This has an unintended effect that the 'time' field in the
458+
# progress bar shows the total time since the beginning of sampling,
459+
# even if the specific chain doesn't start sampling until later on.
460+
for (i, uuid) in enumerate(reverse(uuids))
461+
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
462+
uuid
463+
end
435464

436465
Distributed.@sync begin
437466
if progress
@@ -472,17 +501,21 @@ function mcmcsample(
472501
Random.seed!(_rng, seeds[chainidx])
473502

474503
# Sample a chain and save it to the vector.
504+
child_progressname = "Chain $chainidx/$nchains"
475505
child_progress = if progress == false
476506
false
477507
else
478-
nothing
508+
ProgressLogging.Progress(
509+
uuids[chainidx]; name=child_progressname
510+
)
479511
end
480-
@ifwithprogresslogger progress chains[chainidx] = StatsBase.sample(
512+
chains[chainidx] = StatsBase.sample(
481513
_rng,
482514
_model,
483515
_sampler,
484516
N;
485517
progress=child_progress,
518+
progressname=child_progressname,
486519
initial_params=if initial_params === nothing
487520
nothing
488521
else
@@ -496,6 +529,8 @@ function mcmcsample(
496529
kwargs...,
497530
)
498531

532+
ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
533+
499534
# Update the progress bar.
500535
progress && put!(channel, true)
501536
end

0 commit comments

Comments
 (0)