Skip to content

Commit b5ea802

Browse files
committed
[wip] fix parallel sampling
1 parent ececa17 commit b5ea802

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

src/logging.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
macro ifwithprogresslogger(progress, exprs...)
55
return esc(
66
quote
7-
if $progress
7+
if $progress == true
88
if $hasprogresslevel($Logging.current_logger())
99
$ProgressLogging.@withprogress $(exprs...)
1010
else

src/sample.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ function mcmcsample(
144144
@ifwithprogresslogger progress name = progressname begin
145145
# Determine threshold values for progress logging
146146
# (one update per 0.5% of progress)
147-
if progress
147+
if (progress == true || progress === nothing)
148148
threshold = Ntotal ÷ 200
149149
next_update = threshold
150150
end
@@ -166,8 +166,12 @@ function mcmcsample(
166166

167167
# Update the progress bar.
168168
itotal = 1
169-
if progress && itotal >= next_update
170-
ProgressLogging.@logprogress itotal / Ntotal
169+
if !(progress == false) && itotal >= next_update
170+
if progress == true
171+
ProgressLogging.@logprogress itotal / Ntotal
172+
else
173+
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
174+
end
171175
next_update = itotal + threshold
172176
end
173177

@@ -181,8 +185,12 @@ function mcmcsample(
181185
end
182186

183187
# Update the progress bar.
184-
if progress && (itotal += 1) >= next_update
185-
ProgressLogging.@logprogress itotal / Ntotal
188+
if !(progress == false) && (itotal += 1) >= next_update
189+
if progress == true
190+
ProgressLogging.@logprogress itotal / Ntotal
191+
else
192+
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
193+
end
186194
next_update = itotal + threshold
187195
end
188196
end
@@ -206,8 +214,12 @@ function mcmcsample(
206214
end
207215

208216
# Update progress bar.
209-
if progress && (itotal += 1) >= next_update
210-
ProgressLogging.@logprogress itotal / Ntotal
217+
if !(progress == false) && (itotal += 1) >= next_update
218+
if progress == true
219+
ProgressLogging.@logprogress itotal / Ntotal
220+
else
221+
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
222+
end
211223
next_update = itotal + threshold
212224
end
213225
end
@@ -227,8 +239,12 @@ function mcmcsample(
227239
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
228240

229241
# Update the progress bar.
230-
if progress && (itotal += 1) >= next_update
231-
ProgressLogging.@logprogress itotal / Ntotal
242+
if !(progress == false) && (itotal += 1) >= next_update
243+
if progress == true
244+
ProgressLogging.@logprogress itotal / Ntotal
245+
else
246+
ProgressLogging.@logprogress itotal / Ntotal _id = "hello"
247+
end
232248
next_update = itotal + threshold
233249
end
234250
end
@@ -456,12 +472,17 @@ function mcmcsample(
456472
Random.seed!(_rng, seeds[chainidx])
457473

458474
# Sample a chain and save it to the vector.
459-
chains[chainidx] = StatsBase.sample(
475+
child_progress = if progress == false
476+
false
477+
else
478+
nothing
479+
end
480+
@ifwithprogresslogger progress chains[chainidx] = StatsBase.sample(
460481
_rng,
461482
_model,
462483
_sampler,
463484
N;
464-
progress=false,
485+
progress=child_progress,
465486
initial_params=if initial_params === nothing
466487
nothing
467488
else

0 commit comments

Comments
 (0)