Skip to content

Commit 72c9d6c

Browse files
committed
Fix MCMCDistributed
1 parent a90226c commit 72c9d6c

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

src/sample.jl

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -544,34 +544,50 @@ function mcmcsample(
544544
pool = Distributed.CachingPool(Distributed.workers())
545545

546546
local chains
547-
# Create a channel for progress logging.
548-
if progress
549-
channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers()))
547+
548+
# Create overall progress logging object (tracks number of chains completed)
549+
overall_progress_obj = ProgressMeter.Progress(
550+
nchains; desc=progressname, dt=0.0, enabled=progress
551+
)
552+
# ProgressMeter doesn't start printing until the second iteration or so. This
553+
# forces it to start printing an empty progress bar immediately.
554+
# https://github.com/timholy/ProgressMeter.jl/issues/288
555+
ProgressMeter.update!(overall_progress_obj, 0; force=true)
556+
# Create per-chain progress logging objects
557+
progress_objs = [
558+
ProgressMeter.Progress(
559+
N; desc="Chain $i/$nchains", dt=0.01, enabled=progress, offset=i
560+
) for i in 1:nchains
561+
]
562+
for obj in progress_objs
563+
ProgressMeter.update!(obj, 0; force=true)
550564
end
565+
# Create a channel to synchronise progress updates.
566+
channel = Distributed.RemoteChannel(() -> Channel{Tuple{Int,Bool}}(), 1)
551567

552568
Distributed.@sync begin
553-
if progress
554-
# Update the progress bar.
555-
Distributed.@async begin
556-
# Determine threshold values for progress logging
557-
# (one update per 0.5% of progress)
558-
threshold = nchains ÷ 200
559-
nextprogresschains = threshold
560-
561-
progresschains = 0
562-
while take!(channel)
563-
progresschains += 1
564-
if progresschains >= nextprogresschains
565-
# ProgressLogging.@logprogress progresschains / nchains
566-
nextprogresschains = progresschains + threshold
567-
end
569+
Distributed.@async begin
570+
while true
571+
i, res = take!(channel)
572+
# i == 0 means the overall progress bar; i > 0 means the
573+
# progress bar for chain i.
574+
prog_obj = if i == 0
575+
overall_progress_obj
576+
else
577+
progress_objs[i]
578+
end
579+
if res # true = a chain / sample finished
580+
ProgressMeter.next!(prog_obj)
581+
else # false = all chains / samples finished (or one failed)
582+
ProgressMeter.finish!(prog_obj)
583+
break
568584
end
569585
end
570586
end
571587

572588
Distributed.@async begin
573589
try
574-
function sample_chain(seed, initial_params, initial_state)
590+
function sample_chain(i, seed, initial_params, initial_state)
575591
# Seed a new random number generator with the pre-made seed.
576592
Random.seed!(rng, seed)
577593

@@ -584,21 +600,23 @@ function mcmcsample(
584600
progress=false,
585601
initial_params=initial_params,
586602
initial_state=initial_state,
603+
_chain_idx=i,
604+
_progress_channel=channel,
587605
kwargs...,
588606
)
589607

590608
# Update the progress bar.
591-
progress && put!(channel, true)
609+
progress && put!(channel, (0, true))
592610

593611
# Return the new chain.
594612
return chain
595613
end
596614
chains = Distributed.pmap(
597-
sample_chain, pool, seeds, _initial_params, _initial_state
615+
sample_chain, pool, 1:nchains, seeds, _initial_params, _initial_state
598616
)
599617
finally
600618
# Stop updating the progress bar.
601-
progress && put!(channel, false)
619+
progress && put!(channel, (0, false))
602620
end
603621
end
604622
end

0 commit comments

Comments
 (0)