1
1
module AbstractMCMC
2
2
3
- using ProgressMeter
3
+ import ProgressLogging
4
4
import StatsBase
5
5
using StatsBase: sample
6
6
7
+ import Distributed
8
+ import Logging
7
9
using Random: GLOBAL_RNG, AbstractRNG, seed!
10
+ import UUIDs
8
11
9
12
"""
10
13
AbstractChains
@@ -36,96 +39,17 @@ An `AbstractModel` represents a generic model type that can be used to perform i
36
39
"""
37
40
abstract type AbstractModel end
38
41
39
- """
40
- AbstractCallback
41
-
42
- An `AbstractCallback` types is a supertype to be inherited from if you want to use custom callback
43
- functionality. This is used to report sampling progress such as parameters calculated, remaining
44
- samples to run, or even plot graphs if you so choose.
45
-
46
- In order to implement callback functionality, you need the following:
47
-
48
- - A mutable struct that is a subtype of `AbstractCallback`
49
- - An overload of the `init_callback` function
50
- - An overload of the `callback` function
51
- """
52
- abstract type AbstractCallback end
53
-
54
- """
55
- NoCallback()
56
-
57
- This disables the callback functionality in the event that you wish to
58
- implement your own callback or reporting.
59
- """
60
- mutable struct NoCallback <: AbstractCallback end
61
-
62
- """
63
- DefaultCallback(N::Int)
64
-
65
- The default callback struct which uses `ProgressMeter`.
66
- """
67
- mutable struct DefaultCallback{
68
- ProgType<: ProgressMeter.AbstractProgress
69
- } <: AbstractCallback
70
- p :: ProgType
71
- end
72
-
73
- DefaultCallback (N:: Int ) = DefaultCallback (ProgressMeter. Progress (N, 1 ))
74
-
75
- function init_callback (
76
- rng:: AbstractRNG ,
77
- ℓ:: ModelType ,
78
- s:: SamplerType ,
79
- N:: Integer ;
80
- kwargs...
81
- ) where {ModelType<: AbstractModel , SamplerType<: AbstractSampler }
82
- return DefaultCallback (N)
83
- end
84
-
85
- """
86
- _generate_callback(
87
- rng::AbstractRNG,
88
- ℓ::ModelType,
89
- s::SamplerType,
90
- N::Integer;
91
- progress_style=:default,
92
- kwargs...
93
- )
94
-
95
- `_generate_callback` uses a `progress_style` keyword argument to determine
96
- which progress meter style should be used. This function is strictly internal
97
- and is not meant to be overloaded. If you intend to add a custom `AbstractCallback`,
98
- you should overload `init_callback` instead.
99
-
100
- Options for `progress_style` include:
101
-
102
- - `:default` which returns the result of `init_callback`
103
- - `false` or `:disable` which returns a `NoCallback`
104
- - `:plain` which returns the default, simple `DefaultCallback`.
105
- """
106
- function _generate_callback (
107
- rng:: AbstractRNG ,
108
- ℓ:: ModelType ,
109
- s:: SamplerType ,
110
- N:: Integer ;
111
- progress_style= :default ,
112
- kwargs...
113
- ) where {ModelType<: AbstractModel , SamplerType<: AbstractSampler }
114
- if progress_style == :default
115
- return init_callback (rng, ℓ, s, N; kwargs... )
116
- elseif progress_style == false || progress_style == :disable
117
- return NoCallback ()
118
- elseif progress_style == :plain
119
- return DefaultCallback (N)
120
- else
121
- throw (ArgumentError (" Keyword argument $progress_style is not recognized." ))
122
- end
123
- end
124
-
125
42
"""
126
43
sample([rng, ]model, sampler, N; kwargs...)
127
44
128
45
Return `N` samples from the MCMC `sampler` for the provided `model`.
46
+
47
+ If a callback function `f` with type signature
48
+ ```julia
49
+ f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
50
+ iteration::Integer, transition; kwargs...)
51
+ ```
52
+ may be provided as keyword argument `callback`. It is called after every sampling step.
129
53
"""
130
54
function StatsBase. sample (
131
55
model:: AbstractModel ,
@@ -141,7 +65,9 @@ function StatsBase.sample(
141
65
model:: AbstractModel ,
142
66
sampler:: AbstractSampler ,
143
67
N:: Integer ;
144
- progress:: Bool = true ,
68
+ progress = true ,
69
+ progressname = " Sampling" ,
70
+ callback = (args... ; kwargs... ) -> nothing ,
145
71
chain_type:: Type = Any,
146
72
kwargs...
147
73
)
@@ -151,29 +77,54 @@ function StatsBase.sample(
151
77
# Perform any necessary setup.
152
78
sample_init! (rng, model, sampler, N; kwargs... )
153
79
154
- # Add a progress meter.
155
- progress && (cb = _generate_callback (rng, model, sampler, N; kwargs... ))
156
-
157
- # Obtain the initial transition.
158
- transition = step! (rng, model, sampler, N; iteration= 1 , kwargs... )
159
-
160
- # Save the transition.
161
- transitions = transitions_init (transition, model, sampler, N; kwargs... )
162
- transitions_save! (transitions, 1 , transition, model, sampler, N; kwargs... )
80
+ # Create a progress bar.
81
+ if progress
82
+ progressid = UUIDs. uuid4 ()
83
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
84
+ _id= progressid)
85
+ end
163
86
164
- # Update the progress meter.
165
- progress && callback (rng, model, sampler, N, 1 , transition, cb; kwargs... )
87
+ local transitions
88
+ try
89
+ # Obtain the initial transition.
90
+ transition = step! (rng, model, sampler, N; iteration= 1 , kwargs... )
166
91
167
- # Step through the sampler.
168
- for i in 2 : N
169
- # Obtain the next transition.
170
- transition = step! (rng, model, sampler, N, transition; iteration= i, kwargs... )
92
+ # Run callback.
93
+ callback (rng, model, sampler, N, 1 , transition; kwargs... )
171
94
172
95
# Save the transition.
173
- transitions_save! (transitions, i, transition, model, sampler, N; kwargs... )
174
-
175
- # Update the progress meter.
176
- progress && callback (rng, model, sampler, N, i, transition, cb; kwargs... )
96
+ transitions = transitions_init (transition, model, sampler, N; kwargs... )
97
+ transitions_save! (transitions, 1 , transition, model, sampler, N; kwargs... )
98
+
99
+ # Update the progress bar.
100
+ if progress
101
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= 1 / N,
102
+ _id= progressid)
103
+ end
104
+
105
+ # Step through the sampler.
106
+ for i in 2 : N
107
+ # Obtain the next transition.
108
+ transition = step! (rng, model, sampler, N, transition; iteration= i, kwargs... )
109
+
110
+ # Run callback.
111
+ callback (rng, model, sampler, N, i, transition; kwargs... )
112
+
113
+ # Save the transition.
114
+ transitions_save! (transitions, i, transition, model, sampler, N; kwargs... )
115
+
116
+ # Update the progress bar.
117
+ if progress
118
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= i/ N,
119
+ _id= progressid)
120
+ end
121
+ end
122
+ finally
123
+ # Close the progress bar.
124
+ if progress
125
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= " done" ,
126
+ _id= progressid)
127
+ end
177
128
end
178
129
179
130
# Wrap up the sampler, if necessary.
@@ -301,57 +252,6 @@ function transitions_save!(
301
252
return
302
253
end
303
254
304
- """
305
- callback(
306
- rng::AbstractRNG,
307
- ℓ::ModelType,
308
- s::SamplerType,
309
- N::Integer,
310
- iteration::Integer,
311
- cb::CallbackType;
312
- kwargs...
313
- )
314
-
315
- `callback` is called after every sample run, and allows you to run some function on a
316
- subtype of `AbstractCallback`. Typically this is used to increment a progress meter, show a
317
- plot of parameter draws, or otherwise provide information about the sampling process to the user.
318
-
319
- By default, `ProgressMeter` is used to show the number of samples remaning.
320
- """
321
- function callback (
322
- rng:: AbstractRNG ,
323
- ℓ:: ModelType ,
324
- s:: SamplerType ,
325
- N:: Integer ,
326
- iteration:: Integer ,
327
- transition,
328
- cb:: CallbackType ;
329
- kwargs...
330
- ) where {
331
- ModelType<: AbstractModel ,
332
- SamplerType<: AbstractSampler ,
333
- CallbackType<: AbstractCallback ,
334
- }
335
- # Default callback behavior.
336
- ProgressMeter. next! (cb. p)
337
- end
338
-
339
- function callback (
340
- rng:: AbstractRNG ,
341
- ℓ:: ModelType ,
342
- s:: SamplerType ,
343
- N:: Integer ,
344
- iteration:: Integer ,
345
- transition,
346
- cb:: NoCallback ;
347
- kwargs...
348
- ) where {
349
- ModelType<: AbstractModel ,
350
- SamplerType<: AbstractSampler ,
351
- }
352
- # Do nothing.
353
- end
354
-
355
255
"""
356
256
psample([rng::AbstractRNG, ]model::AbstractModel, sampler::AbstractSampler, N::Integer,
357
257
nchains::Integer; kwargs...)
@@ -377,6 +277,8 @@ function psample(
377
277
sampler:: AbstractSampler ,
378
278
N:: Integer ,
379
279
nchains:: Integer ;
280
+ progress = true ,
281
+ progressname = " Parallel sampling" ,
380
282
kwargs...
381
283
)
382
284
# Copy the random number generator, model, and sample for each thread
@@ -390,16 +292,55 @@ function psample(
390
292
# Set up a chains vector.
391
293
chains = Vector {Any} (undef, nchains)
392
294
393
- Threads. @threads for i in 1 : nchains
394
- # Obtain the ID of the current thread.
395
- id = Threads. threadid ()
295
+ # Create a progress bar and a channel for progress logging.
296
+ if progress
297
+ progressid = UUIDs. uuid4 ()
298
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
299
+ _id= progressid)
300
+ channel = Distributed. RemoteChannel (() -> Channel {Bool} (nchains), 1 )
301
+ end
396
302
397
- # Seed the thread-specific random number generator with the pre-made seed.
398
- subrng = rngs[id]
399
- seed! (subrng, seeds[i])
400
-
401
- # Sample a chain and save it to the vector.
402
- chains[i] = sample (subrng, models[id] , samplers[id], N; progress= false , kwargs... )
303
+ try
304
+ Distributed. @sync begin
305
+ if progress
306
+ Distributed. @async begin
307
+ # Update the progress bar.
308
+ progresschains = 0
309
+ while take! (channel)
310
+ progresschains += 1
311
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
312
+ progress= progresschains/ nchains, _id= progressid)
313
+ end
314
+ end
315
+ end
316
+
317
+ Distributed. @async begin
318
+ Threads. @threads for i in 1 : nchains
319
+ # Obtain the ID of the current thread.
320
+ id = Threads. threadid ()
321
+
322
+ # Seed the thread-specific random number generator with the pre-made seed.
323
+ subrng = rngs[id]
324
+ seed! (subrng, seeds[i])
325
+
326
+ # Sample a chain and save it to the vector.
327
+ chains[i] = sample (subrng, models[id], samplers[id], N;
328
+ progress = false , kwargs... )
329
+
330
+ # Update the progress bar.
331
+ progress && put! (channel, true )
332
+ end
333
+
334
+ # Stop updating the progress bar.
335
+ progress && put! (channel, false )
336
+ end
337
+ end
338
+ finally
339
+ # Close the progress bar.
340
+ if progress
341
+ Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
342
+ progress= " done" , _id= progressid)
343
+ end
403
344
end
404
345
405
346
# Concatenate the chains together.
0 commit comments