1
1
module AbstractMCMC
2
2
3
+ import ConsoleProgressMonitor
4
+ import LoggingExtras
3
5
import ProgressLogging
4
6
import StatsBase
5
7
using StatsBase: sample
8
+ import TerminalLoggers
6
9
7
10
import Distributed
8
11
import Logging
9
12
using Random: GLOBAL_RNG, AbstractRNG, seed!
10
- import UUIDs
13
+
14
+ # avoid creating a progress bar with @withprogress if progress logging is disabled
15
+ # and add a custom progress logger if the current logger does not seem to be able to handle
16
+ # progress logs
17
+ macro ifwithprogresslogger (progress, exprs... )
18
+ return quote
19
+ if $ progress
20
+ if $ hasprogresslevel ($ Logging. current_logger ())
21
+ $ ProgressLogging. @withprogress $ (exprs... )
22
+ else
23
+ $ with_progresslogger ($ Base. @__MODULE__ , $ Logging. current_logger ()) do
24
+ $ ProgressLogging. @withprogress $ (exprs... )
25
+ end
26
+ end
27
+ else
28
+ $ (exprs[end ])
29
+ end
30
+ end |> esc
31
+ end
32
+
33
+ # improved checks?
34
+ function hasprogresslevel (logger)
35
+ return Logging. min_enabled_level (logger) ≤ ProgressLogging. ProgressLevel
36
+ end
37
+
38
+ # filter better, e.g., according to group?
39
+ function with_progresslogger (f, _module, logger)
40
+ logger1 = LoggingExtras. EarlyFilteredLogger (progresslogger ()) do log
41
+ log. _module === _module && log. level == ProgressLogging. ProgressLevel
42
+ end
43
+ logger2 = LoggingExtras. EarlyFilteredLogger (logger) do log
44
+ log. _module != = _module || log. level != ProgressLogging. ProgressLevel
45
+ end
46
+
47
+ Logging. with_logger (f, LoggingExtras. TeeLogger (logger1, logger2))
48
+ end
49
+
50
+ function progresslogger ()
51
+ # detect if code is running under IJulia since TerminalLogger does not work with IJulia
52
+ # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
53
+ if isdefined (Main, :IJulia ) && Main. IJulia. inited
54
+ return ConsoleProgressMonitor. ProgressLogger ()
55
+ else
56
+ return TerminalLoggers. TerminalLogger ()
57
+ end
58
+ end
11
59
12
60
"""
13
61
AbstractChains
@@ -44,7 +92,7 @@ abstract type AbstractModel end
44
92
45
93
Return `N` samples from the MCMC `sampler` for the provided `model`.
46
94
47
- If a callback function `f` with type signature
95
+ If a callback function `f` with type signature
48
96
```julia
49
97
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
50
98
iteration::Integer, transition; kwargs...)
@@ -77,15 +125,7 @@ function StatsBase.sample(
77
125
# Perform any necessary setup.
78
126
sample_init! (rng, model, sampler, N; kwargs... )
79
127
80
- # Create a progress bar.
81
- if progress
82
- progressid = UUIDs. uuid4 ()
83
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
84
- _id= progressid)
85
- end
86
-
87
- local transitions
88
- try
128
+ @ifwithprogresslogger progress name= progressname begin
89
129
# Obtain the initial transition.
90
130
transition = step! (rng, model, sampler, N; iteration= 1 , kwargs... )
91
131
@@ -97,10 +137,7 @@ function StatsBase.sample(
97
137
transitions_save! (transitions, 1 , transition, model, sampler, N; kwargs... )
98
138
99
139
# Update the progress bar.
100
- if progress
101
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= 1 / N,
102
- _id= progressid)
103
- end
140
+ progress && ProgressLogging. @logprogress 1 / N
104
141
105
142
# Step through the sampler.
106
143
for i in 2 : N
@@ -114,16 +151,7 @@ function StatsBase.sample(
114
151
transitions_save! (transitions, i, transition, model, sampler, N; kwargs... )
115
152
116
153
# Update the progress bar.
117
- if progress
118
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= i/ N,
119
- _id= progressid)
120
- end
121
- end
122
- finally
123
- # Close the progress bar.
124
- if progress
125
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= " done" ,
126
- _id= progressid)
154
+ progress && ProgressLogging. @logprogress i/ N
127
155
end
128
156
end
129
157
@@ -178,12 +206,12 @@ function sample_end!(
178
206
end
179
207
180
208
function bundle_samples (
181
- :: AbstractRNG ,
182
- :: AbstractModel ,
183
- :: AbstractSampler ,
184
- :: Integer ,
209
+ :: AbstractRNG ,
210
+ :: AbstractModel ,
211
+ :: AbstractSampler ,
212
+ :: Integer ,
185
213
transitions,
186
- :: Type{Any} ;
214
+ :: Type{Any} ;
187
215
kwargs...
188
216
)
189
217
return transitions
259
287
Sample `nchains` chains using the available threads, and combine them into a single chain.
260
288
261
289
By default, the random number generator, the model and the samplers are deep copied for each
262
- thread to prevent contamination between threads.
290
+ thread to prevent contamination between threads.
263
291
"""
264
292
function psample (
265
293
model:: AbstractModel ,
@@ -292,24 +320,20 @@ function psample(
292
320
# Set up a chains vector.
293
321
chains = Vector {Any} (undef, nchains)
294
322
295
- # Create a progress bar and a channel for progress logging.
296
- if progress
297
- progressid = UUIDs. uuid4 ()
298
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
299
- _id= progressid)
300
- channel = Distributed. RemoteChannel (() -> Channel {Bool} (nchains), 1 )
301
- end
323
+ @ifwithprogresslogger progress name= progressname begin
324
+ # Create a channel for progress logging.
325
+ if progress
326
+ channel = Distributed. RemoteChannel (() -> Channel {Bool} (nchains), 1 )
327
+ end
302
328
303
- try
304
329
Distributed. @sync begin
305
330
if progress
306
331
Distributed. @async begin
307
332
# Update the progress bar.
308
333
progresschains = 0
309
334
while take! (channel)
310
335
progresschains += 1
311
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
312
- progress= progresschains/ nchains, _id= progressid)
336
+ ProgressLogging. @logprogress progresschains/ nchains
313
337
end
314
338
end
315
339
end
@@ -322,7 +346,7 @@ function psample(
322
346
# Seed the thread-specific random number generator with the pre-made seed.
323
347
subrng = rngs[id]
324
348
seed! (subrng, seeds[i])
325
-
349
+
326
350
# Sample a chain and save it to the vector.
327
351
chains[i] = sample (subrng, models[id], samplers[id], N;
328
352
progress = false , kwargs... )
@@ -335,12 +359,6 @@ function psample(
335
359
progress && put! (channel, false )
336
360
end
337
361
end
338
- finally
339
- # Close the progress bar.
340
- if progress
341
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
342
- progress= " done" , _id= progressid)
343
- end
344
362
end
345
363
346
364
# Concatenate the chains together.
0 commit comments