Skip to content

Commit c867a65

Browse files
authored
Merge pull request #17 from TuringLang/progresslogging
Use ProgressLogging instead of ProgressMeter
2 parents 9f8c6fa + d98a0a9 commit c867a65

File tree

3 files changed

+142
-176
lines changed

3 files changed

+142
-176
lines changed

Project.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "0.4.0"
6+
version = "0.5.0"
77

88
[deps]
9-
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
9+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
10+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
11+
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1012
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1113
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
14+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1215

1316
[compat]
14-
ProgressMeter = "1.2"
17+
ProgressLogging = "0.1"
1518
StatsBase = "0.32"
1619
julia = "1"
1720

1821
[extras]
1922
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23+
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2024
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2125

2226
[targets]
23-
test = ["Statistics", "Test"]
27+
test = ["Statistics", "Test", "TerminalLoggers"]

src/AbstractMCMC.jl

Lines changed: 108 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
module AbstractMCMC
22

3-
using ProgressMeter
3+
import ProgressLogging
44
import StatsBase
55
using StatsBase: sample
66

7+
import Distributed
8+
import Logging
79
using Random: GLOBAL_RNG, AbstractRNG, seed!
10+
import UUIDs
811

912
"""
1013
AbstractChains
@@ -36,96 +39,17 @@ An `AbstractModel` represents a generic model type that can be used to perform i
3639
"""
3740
abstract type AbstractModel end
3841

39-
"""
40-
AbstractCallback
41-
42-
An `AbstractCallback` types is a supertype to be inherited from if you want to use custom callback
43-
functionality. This is used to report sampling progress such as parameters calculated, remaining
44-
samples to run, or even plot graphs if you so choose.
45-
46-
In order to implement callback functionality, you need the following:
47-
48-
- A mutable struct that is a subtype of `AbstractCallback`
49-
- An overload of the `init_callback` function
50-
- An overload of the `callback` function
51-
"""
52-
abstract type AbstractCallback end
53-
54-
"""
55-
NoCallback()
56-
57-
This disables the callback functionality in the event that you wish to
58-
implement your own callback or reporting.
59-
"""
60-
mutable struct NoCallback <: AbstractCallback end
61-
62-
"""
63-
DefaultCallback(N::Int)
64-
65-
The default callback struct which uses `ProgressMeter`.
66-
"""
67-
mutable struct DefaultCallback{
68-
ProgType<:ProgressMeter.AbstractProgress
69-
} <: AbstractCallback
70-
p :: ProgType
71-
end
72-
73-
DefaultCallback(N::Int) = DefaultCallback(ProgressMeter.Progress(N, 1))
74-
75-
function init_callback(
76-
rng::AbstractRNG,
77-
::ModelType,
78-
s::SamplerType,
79-
N::Integer;
80-
kwargs...
81-
) where {ModelType<:AbstractModel, SamplerType<:AbstractSampler}
82-
return DefaultCallback(N)
83-
end
84-
85-
"""
86-
_generate_callback(
87-
rng::AbstractRNG,
88-
ℓ::ModelType,
89-
s::SamplerType,
90-
N::Integer;
91-
progress_style=:default,
92-
kwargs...
93-
)
94-
95-
`_generate_callback` uses a `progress_style` keyword argument to determine
96-
which progress meter style should be used. This function is strictly internal
97-
and is not meant to be overloaded. If you intend to add a custom `AbstractCallback`,
98-
you should overload `init_callback` instead.
99-
100-
Options for `progress_style` include:
101-
102-
- `:default` which returns the result of `init_callback`
103-
- `false` or `:disable` which returns a `NoCallback`
104-
- `:plain` which returns the default, simple `DefaultCallback`.
105-
"""
106-
function _generate_callback(
107-
rng::AbstractRNG,
108-
::ModelType,
109-
s::SamplerType,
110-
N::Integer;
111-
progress_style=:default,
112-
kwargs...
113-
) where {ModelType<:AbstractModel, SamplerType<:AbstractSampler}
114-
if progress_style == :default
115-
return init_callback(rng, ℓ, s, N; kwargs...)
116-
elseif progress_style == false || progress_style == :disable
117-
return NoCallback()
118-
elseif progress_style == :plain
119-
return DefaultCallback(N)
120-
else
121-
throw(ArgumentError("Keyword argument $progress_style is not recognized."))
122-
end
123-
end
124-
12542
"""
12643
sample([rng, ]model, sampler, N; kwargs...)
12744
12845
Return `N` samples from the MCMC `sampler` for the provided `model`.
46+
47+
If a callback function `f` with type signature
48+
```julia
49+
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
50+
iteration::Integer, transition; kwargs...)
51+
```
52+
may be provided as keyword argument `callback`. It is called after every sampling step.
12953
"""
13054
function StatsBase.sample(
13155
model::AbstractModel,
@@ -141,7 +65,9 @@ function StatsBase.sample(
14165
model::AbstractModel,
14266
sampler::AbstractSampler,
14367
N::Integer;
144-
progress::Bool=true,
68+
progress = true,
69+
progressname = "Sampling",
70+
callback = (args...; kwargs...) -> nothing,
14571
chain_type::Type=Any,
14672
kwargs...
14773
)
@@ -151,29 +77,54 @@ function StatsBase.sample(
15177
# Perform any necessary setup.
15278
sample_init!(rng, model, sampler, N; kwargs...)
15379

154-
# Add a progress meter.
155-
progress && (cb = _generate_callback(rng, model, sampler, N; kwargs...))
156-
157-
# Obtain the initial transition.
158-
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)
159-
160-
# Save the transition.
161-
transitions = transitions_init(transition, model, sampler, N; kwargs...)
162-
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
80+
# Create a progress bar.
81+
if progress
82+
progressid = UUIDs.uuid4()
83+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
84+
_id=progressid)
85+
end
16386

164-
# Update the progress meter.
165-
progress && callback(rng, model, sampler, N, 1, transition, cb; kwargs...)
87+
local transitions
88+
try
89+
# Obtain the initial transition.
90+
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)
16691

167-
# Step through the sampler.
168-
for i in 2:N
169-
# Obtain the next transition.
170-
transition = step!(rng, model, sampler, N, transition; iteration=i, kwargs...)
92+
# Run callback.
93+
callback(rng, model, sampler, N, 1, transition; kwargs...)
17194

17295
# Save the transition.
173-
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
174-
175-
# Update the progress meter.
176-
progress && callback(rng, model, sampler, N, i, transition, cb; kwargs...)
96+
transitions = transitions_init(transition, model, sampler, N; kwargs...)
97+
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
98+
99+
# Update the progress bar.
100+
if progress
101+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=1/N,
102+
_id=progressid)
103+
end
104+
105+
# Step through the sampler.
106+
for i in 2:N
107+
# Obtain the next transition.
108+
transition = step!(rng, model, sampler, N, transition; iteration=i, kwargs...)
109+
110+
# Run callback.
111+
callback(rng, model, sampler, N, i, transition; kwargs...)
112+
113+
# Save the transition.
114+
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
115+
116+
# Update the progress bar.
117+
if progress
118+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=i/N,
119+
_id=progressid)
120+
end
121+
end
122+
finally
123+
# Close the progress bar.
124+
if progress
125+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress="done",
126+
_id=progressid)
127+
end
177128
end
178129

179130
# Wrap up the sampler, if necessary.
@@ -301,57 +252,6 @@ function transitions_save!(
301252
return
302253
end
303254

304-
"""
305-
callback(
306-
rng::AbstractRNG,
307-
ℓ::ModelType,
308-
s::SamplerType,
309-
N::Integer,
310-
iteration::Integer,
311-
cb::CallbackType;
312-
kwargs...
313-
)
314-
315-
`callback` is called after every sample run, and allows you to run some function on a
316-
subtype of `AbstractCallback`. Typically this is used to increment a progress meter, show a
317-
plot of parameter draws, or otherwise provide information about the sampling process to the user.
318-
319-
By default, `ProgressMeter` is used to show the number of samples remaning.
320-
"""
321-
function callback(
322-
rng::AbstractRNG,
323-
::ModelType,
324-
s::SamplerType,
325-
N::Integer,
326-
iteration::Integer,
327-
transition,
328-
cb::CallbackType;
329-
kwargs...
330-
) where {
331-
ModelType<:AbstractModel,
332-
SamplerType<:AbstractSampler,
333-
CallbackType<:AbstractCallback,
334-
}
335-
# Default callback behavior.
336-
ProgressMeter.next!(cb.p)
337-
end
338-
339-
function callback(
340-
rng::AbstractRNG,
341-
::ModelType,
342-
s::SamplerType,
343-
N::Integer,
344-
iteration::Integer,
345-
transition,
346-
cb::NoCallback;
347-
kwargs...
348-
) where {
349-
ModelType<:AbstractModel,
350-
SamplerType<:AbstractSampler,
351-
}
352-
# Do nothing.
353-
end
354-
355255
"""
356256
psample([rng::AbstractRNG, ]model::AbstractModel, sampler::AbstractSampler, N::Integer,
357257
nchains::Integer; kwargs...)
@@ -377,6 +277,8 @@ function psample(
377277
sampler::AbstractSampler,
378278
N::Integer,
379279
nchains::Integer;
280+
progress = true,
281+
progressname = "Parallel sampling",
380282
kwargs...
381283
)
382284
# Copy the random number generator, model, and sample for each thread
@@ -390,16 +292,55 @@ function psample(
390292
# Set up a chains vector.
391293
chains = Vector{Any}(undef, nchains)
392294

393-
Threads.@threads for i in 1:nchains
394-
# Obtain the ID of the current thread.
395-
id = Threads.threadid()
295+
# Create a progress bar and a channel for progress logging.
296+
if progress
297+
progressid = UUIDs.uuid4()
298+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
299+
_id=progressid)
300+
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
301+
end
396302

397-
# Seed the thread-specific random number generator with the pre-made seed.
398-
subrng = rngs[id]
399-
seed!(subrng, seeds[i])
400-
401-
# Sample a chain and save it to the vector.
402-
chains[i] = sample(subrng, models[id] , samplers[id], N; progress=false, kwargs...)
303+
try
304+
Distributed.@sync begin
305+
if progress
306+
Distributed.@async begin
307+
# Update the progress bar.
308+
progresschains = 0
309+
while take!(channel)
310+
progresschains += 1
311+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
312+
progress=progresschains/nchains, _id=progressid)
313+
end
314+
end
315+
end
316+
317+
Distributed.@async begin
318+
Threads.@threads for i in 1:nchains
319+
# Obtain the ID of the current thread.
320+
id = Threads.threadid()
321+
322+
# Seed the thread-specific random number generator with the pre-made seed.
323+
subrng = rngs[id]
324+
seed!(subrng, seeds[i])
325+
326+
# Sample a chain and save it to the vector.
327+
chains[i] = sample(subrng, models[id], samplers[id], N;
328+
progress = false, kwargs...)
329+
330+
# Update the progress bar.
331+
progress && put!(channel, true)
332+
end
333+
334+
# Stop updating the progress bar.
335+
progress && put!(channel, false)
336+
end
337+
end
338+
finally
339+
# Close the progress bar.
340+
if progress
341+
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
342+
progress="done", _id=progressid)
343+
end
403344
end
404345

405346
# Concatenate the chains together.

0 commit comments

Comments
 (0)