Merged
Changes from 2 commits
30 commits
b3434ac
[wip] fix parallel sampling
penelopeysm Jun 26, 2025
15250c7
Parallel sampling with ProgressLogging
penelopeysm Jun 27, 2025
367718b
destroy per-chain progress bars if an error occurs
penelopeysm Jun 27, 2025
60134f8
add a todo
penelopeysm Jun 27, 2025
e0ae513
Fix implementation
penelopeysm Jun 30, 2025
6b514e4
Bump minor version
penelopeysm Jun 30, 2025
bbda3c8
Add `setmaxchainsprogress!`
penelopeysm Jun 30, 2025
a9e5306
Don't duplicate macro
penelopeysm Jun 30, 2025
a03692d
:overall works with MCMCDistributed now
penelopeysm Jun 30, 2025
838db60
Give up on :perchain for MCMCDistributed
penelopeysm Jun 30, 2025
6b59b21
Fix comments
penelopeysm Jun 30, 2025
b340ebc
Remove dead code
penelopeysm Jun 30, 2025
1195503
Undelete some not-actually-dead code
penelopeysm Jun 30, 2025
594483f
Broaden UUIDs compat so that it works on older Julia versions
penelopeysm Jun 30, 2025
7def4b4
Explain progress logging in docs
penelopeysm Jun 30, 2025
022678e
Remove dead code
penelopeysm Jul 1, 2025
5b2577f
Fix channel buffering for MCMCThreads
penelopeysm Jul 1, 2025
cefafb0
Attempt to use proper types for logging
penelopeysm Jul 1, 2025
c6f9e78
Refactor logging, throttle per-chain updates
penelopeysm Jul 1, 2025
d9c2e86
Improve comment
penelopeysm Jul 1, 2025
f8a8b64
add warning
penelopeysm Jul 1, 2025
64b0bfb
fix convergence sampling
penelopeysm Jul 1, 2025
27569b3
Don't use integer division
penelopeysm Jul 1, 2025
4cd647a
remove extra show
penelopeysm Jul 1, 2025
9f8970d
Rename withprogresslogger macro
penelopeysm Jul 1, 2025
3e43f6a
Add exclamation marks to function names
penelopeysm Jul 15, 2025
7276fc2
Improve clarity of user-facing documentation
penelopeysm Jul 15, 2025
284741f
Make `:overall` the default, remove `setmaxchainsprogress!`
penelopeysm Jul 23, 2025
5c5b912
Make :perchain use the richer overall progress bar
penelopeysm Jul 23, 2025
eebb10b
Fix a typo
penelopeysm Jul 23, 2025
31 changes: 22 additions & 9 deletions docs/src/api.md
@@ -92,32 +92,45 @@ To ensure that sampling multiple chains "just works" when sampling of a single c

## Progress logging

-The default value for the `progress` keyword argument is `AbstractMCMC.PROGRESS[]`, which is always set to `true` unless modified with `AbstractMCMC.setprogress!`.
-For example, `setprogress!(false)` will disable all progress logging.
+Progress logging is controlled in one of two ways:

-```@docs
-AbstractMCMC.setprogress!
-```
+- by passing the `progress` keyword argument to the `sample(...)` function, or
+- by globally changing the defaults with `AbstractMCMC.setprogress!` and `AbstractMCMC.setmaxchainsprogress!`.

### `progress` keyword argument

For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`.

For multiple-chain sampling using `MCMCThreads`, there are several, more detailed, options:

-- `:perchain`: create one progress bar per chain being sampled
+- `:perchain`: create one progress bar per chain being sampled, plus one progress bar tracking the number of chains
- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been sampled across all chains
- `:none`: do not create any progress bar
- `true` (the default): use `perchain` for 10 or fewer chains, and `overall` for more than 10 chains
- `false`: same as `none`, i.e. no progress bar

should be "same as :none"


and also "use :perchain for 10 or fewer chains, and :overall" 😅

Member Author


Thanks, changed.


-The threshold of 10 chains can be changed using `AbstractMCMC.setmaxchainsprogress!(N)`, which will cause `MCMCThreads` to use `:perchain` for `N` or fewer chains, and `:overall` for more than `N` chains.
-Thus, for example, if you _always_ want to use `:overall`, you can call `AbstractMCMC.setmaxchainsprogress!(0)`.

Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented.
So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`.
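
As an illustration of the rules described in this hunk, the mapping from a `progress` value to a display mode could be sketched roughly as follows. The function `resolve_progress` and its `max_chains` keyword are hypothetical names used only for this sketch; they are not AbstractMCMC functions, and the actual logic inside `mcmcsample` may differ.

```julia
# Hypothetical helper, not part of AbstractMCMC: it only restates the
# documented rules for `MCMCThreads` in executable form.
function resolve_progress(
    progress::Union{Bool,Symbol}, nchains::Integer; max_chains::Integer=10
)
    if progress === true
        # `true`: per-chain bars up to the threshold, one overall bar beyond it.
        return nchains <= max_chains ? :perchain : :overall
    elseif progress === false
        # `false` behaves like `:none`.
        return :none
    elseif progress in (:perchain, :overall, :none)
        # An explicit symbol is used as given.
        return progress
    else
        throw(ArgumentError("unsupported progress value: $(repr(progress))"))
    end
end

resolve_progress(true, 4)                 # per-chain bars for a few chains
resolve_progress(true, 50)                # a single overall bar for many chains
resolve_progress(true, 4; max_chains=0)   # a threshold of 0 always gives the overall bar
```

Under these assumptions, passing `max_chains=0` plays the role that `setmaxchainsprogress!(0)` plays in the text above.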

!!! warning "Do not override the `progress` keyword argument"
    If you are implementing your own methods for `sample(...)`, you should make sure to not override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls.

+### Global settings

+If you are sampling multiple times and would like to change the default behaviour, you can use these functions to control progress logging globally:

+```@docs
+AbstractMCMC.setprogress!
+AbstractMCMC.setmaxchainsprogress!
+```

+`setprogress!` is more general, and applies to all types of sampling (both single- and multiple-chain).
+It only takes a boolean argument, which switches progress logging on or off.
+For example, `setprogress!(false)` will disable all progress logging.

+On the other hand, `setmaxchainsprogress!` is specific to multiple-chain sampling, and allows you to set the threshold for when to switch from `:perchain` to `:overall` progress logging.
+Thus, for example, if you want to keep progress logging on but _always_ want to use `:overall`, you can set `AbstractMCMC.setmaxchainsprogress!(0)`.

## Chains

The `chain_type` keyword argument allows to set the type of the returned chain. A common
24 changes: 12 additions & 12 deletions src/logging.jl
@@ -18,13 +18,13 @@ struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg
return new{typeof(name)}(name, UUIDs.uuid4())
end
end

Should the below functions have a trailing !? They aren't mutating the parameters, but some global state, right?

(Doesn't really matter, we just as well leave the names as is)

Member Author


Indeed, they are mutating (although the macro hides the worst of it). Will change.

-function init_progress(p::CreateNewProgressBar)
+function init_progress!(p::CreateNewProgressBar)
ProgressLogging.@logprogress p.name nothing _id = p.uuid
end
-function update_progress(p::CreateNewProgressBar, progress_frac)
+function update_progress!(p::CreateNewProgressBar, progress_frac)
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
end
-function finish_progress(p::CreateNewProgressBar)
+function finish_progress!(p::CreateNewProgressBar)
ProgressLogging.@logprogress p.name "done" _id = p.uuid
end

@@ -34,9 +34,9 @@ end
Do not log progress at all.
"""
struct NoLogging <: AbstractProgressKwarg end
-init_progress(::NoLogging) = nothing
-update_progress(::NoLogging, ::Any) = nothing
-finish_progress(::NoLogging) = nothing
+init_progress!(::NoLogging) = nothing
+update_progress!(::NoLogging, ::Any) = nothing
+finish_progress!(::NoLogging) = nothing

"""
ExistingProgressBar
@@ -51,7 +51,7 @@ struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
name::S
uuid::UUIDs.UUID
end
-function init_progress(p::ExistingProgressBar)
+function init_progress!(p::ExistingProgressBar)
# Hacky code to reset the start timer if called from a multi-chain sampling
# process. We need this because the progress bar is constructed in the
# multi-chain method, i.e. if we don't do this the progress bar shows the
@@ -65,10 +65,10 @@ function init_progress(p::ExistingProgressBar)
end
ProgressLogging.@logprogress p.name nothing _id = p.uuid
end
-function update_progress(p::ExistingProgressBar, progress_frac)
+function update_progress!(p::ExistingProgressBar, progress_frac)
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
end
-function finish_progress(p::ExistingProgressBar)
+function finish_progress!(p::ExistingProgressBar)
ProgressLogging.@logprogress p.name "done" _id = p.uuid
end

@@ -87,11 +87,11 @@ struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{
channel::T
n_updates::Int
end
-init_progress(::ChannelProgress) = nothing
-update_progress(p::ChannelProgress, ::Any) = put!(p.channel, true)
+init_progress!(::ChannelProgress) = nothing
+update_progress!(p::ChannelProgress, ::Any) = put!(p.channel, true)
# Note: We don't want to `put!(p.channel, false)`, because that would stop the
# channel from being used for further updates e.g. from other chains.
-finish_progress(::ChannelProgress) = nothing
+finish_progress!(::ChannelProgress) = nothing

# Add a custom progress logger if the current logger does not seem to be able to handle
# progress logs.
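
The `init_progress!` / `update_progress!` / `finish_progress!` trio in this file is a plain multiple-dispatch interface: each progress type decides what the three calls do. A minimal standalone sketch of that pattern, together with a throttled update loop in the spirit of the one in `src/sample.jl`, might look like this. The types `RecordingProgress` and `SilentProgress`, the driver `run_steps!`, and the `nupdates` keyword are hypothetical and exist only for this illustration; nothing here calls ProgressLogging or AbstractMCMC.

```julia
# Hypothetical stand-ins for the progress types in src/logging.jl.
abstract type AbstractProgressSketch end

# Records every call so the behaviour can be inspected afterwards.
struct RecordingProgress <: AbstractProgressSketch
    log::Vector{Any}
end
RecordingProgress() = RecordingProgress(Any[])

# Does nothing, analogous in spirit to `NoLogging`.
struct SilentProgress <: AbstractProgressSketch end

# The trailing `!` signals that these calls have side effects.
init_progress!(p::RecordingProgress) = push!(p.log, :init)
update_progress!(p::RecordingProgress, frac) = push!(p.log, frac)
finish_progress!(p::RecordingProgress) = push!(p.log, :done)

init_progress!(::SilentProgress) = nothing
update_progress!(::SilentProgress, ::Any) = nothing
finish_progress!(::SilentProgress) = nothing

# Throttled driver: emit roughly `nupdates` updates over `nsteps` steps.
function run_steps!(progress::AbstractProgressSketch, nsteps::Integer; nupdates::Integer=4)
    init_progress!(progress)
    threshold = nsteps / nupdates  # floating-point division, not `÷`
    next_update = threshold
    for i in 1:nsteps
        # ... one sampling step would go here ...
        if i >= next_update
            update_progress!(progress, i / nsteps)
            next_update += threshold
        end
    end
    finish_progress!(progress)
    return progress
end

p = run_steps!(RecordingProgress(), 8)
# p.log now holds :init, the fractions 0.25, 0.5, 0.75, 1.0, and :done
```

Because the driver only talks to the three functions, swapping `RecordingProgress()` for `SilentProgress()` disables all output without touching the loop.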
34 changes: 17 additions & 17 deletions src/sample.jl
@@ -169,7 +169,7 @@ function mcmcsample(
local state

@maybewithricherlogger begin
-init_progress(progress)
+init_progress!(progress)
# Determine threshold values for progress logging (by default, one
# update per 0.5% of progress, unless this has been passed in
# explicitly)
@@ -195,7 +195,7 @@ function mcmcsample(
# Start the progress bar.
itotal = 1
if itotal >= next_update
-update_progress(progress, itotal / Ntotal)
+update_progress!(progress, itotal / Ntotal)
next_update += threshold
end

@@ -211,7 +211,7 @@ function mcmcsample(
# Update the progress bar.
itotal += 1
if itotal >= next_update
-update_progress(progress, itotal / Ntotal)
+update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end
@@ -237,7 +237,7 @@ function mcmcsample(
# Update progress bar.
itotal += 1
if itotal >= next_update
-update_progress(progress, itotal / Ntotal)
+update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end
@@ -259,11 +259,11 @@ function mcmcsample(
# Update the progress bar.
itotal += 1
if itotal >= next_update
-update_progress(progress, itotal / Ntotal)
+update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end
-finish_progress(progress)
+finish_progress!(progress)
end

# Get the sample stop time.
@@ -322,7 +322,7 @@ function mcmcsample(
local state

@maybewithricherlogger begin
-init_progress(progress)
+init_progress!(progress)
# Obtain the initial sample and state.
sample, state = if num_warmup > 0
if initial_state === nothing
@@ -385,7 +385,7 @@ function mcmcsample(
# Increment iteration counter.
i += 1
end
-finish_progress(progress)
+finish_progress!(progress)
end

# Get the sample stop time.
@@ -474,7 +474,7 @@ function mcmcsample(
# by a channel, but it is not itself a ChannelProgress (because
# ChannelProgress doesn't come with a progress bar).
overall_progress_bar = CreateNewProgressBar(progressname)
-init_progress(overall_progress_bar)
+init_progress!(overall_progress_bar)
# These are the per-chain progress bars. We generate `nchains`
# independent UUIDs for each progress bar
child_progresses = [
@@ -484,7 +484,7 @@ function mcmcsample(
# ProgressLogging prints from the bottom up, and we want chain 1 to
# show up at the top)
for child_progress in reverse(child_progresses)
-init_progress(child_progress)
+init_progress!(child_progress)
end
updates_per_chain = nothing
elseif progress == :overall
@@ -518,11 +518,11 @@ function mcmcsample(
while take!(progress_channel)
itotal += 1
if itotal >= next_update
-update_progress(overall_progress_bar, itotal / Ntotal)
+update_progress!(overall_progress_bar, itotal / Ntotal)
next_update += threshold
end
end
-finish_progress(overall_progress_bar)
+finish_progress!(overall_progress_bar)
end
end

@@ -578,7 +578,7 @@ function mcmcsample(
# Tell the 'main' progress bar that this chain is done.
put!(progress_channel, true)
# Conclude the per-chain progress bar.
-finish_progress(child_progresses[chainidx])
+finish_progress!(child_progresses[chainidx])
end
# Note that if progress == :overall, we don't need to do anything
# because progress on that bar is triggered by
@@ -593,7 +593,7 @@ function mcmcsample(
put!(progress_channel, false)
# Additionally stop the per-chain progress bars
for child_progress in child_progresses
-finish_progress(child_progress)
+finish_progress!(child_progress)
end
elseif progress == :overall
# Stop updating the main progress bar (either if sampling
@@ -670,7 +670,7 @@ function mcmcsample(
chan = Channel{Bool}(Distributed.nworkers())
progress_channel = Distributed.RemoteChannel(() -> chan)
overall_progress_bar = CreateNewProgressBar(progressname)
-init_progress(overall_progress_bar)
+init_progress!(overall_progress_bar)
# See MCMCThreads method for the rationale behind updates_per_chain.
updates_per_chain = max(1, 400 ÷ nchains)
child_progresses = [
@@ -694,11 +694,11 @@ function mcmcsample(
while take!(progress_channel)
itotal += 1
if itotal >= next_update
-update_progress(overall_progress_bar, itotal / Ntotal)
+update_progress!(overall_progress_bar, itotal / Ntotal)
next_update += threshold
end
end
-finish_progress(overall_progress_bar)
+finish_progress!(overall_progress_bar)
end
end
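
The `while take!(progress_channel)` loops in this diff rely on a simple convention: each chain sends `true` for every update, and a single `false` is sent once all chains are finished, which ends the consumer loop. A minimal sketch of that convention using only Base Julia is shown below. The function `count_updates` is a hypothetical name for this illustration, and the task layout is an assumption rather than the structure of `mcmcsample` itself.

```julia
# Hypothetical sketch of the Channel{Bool} convention: `true` means
# "one more update", `false` means "stop listening".
function count_updates(nchains::Integer, nsteps::Integer)
    channel = Channel{Bool}(nchains)

    # Consumer: counts updates until it receives `false`.
    consumer = Threads.@spawn begin
        n = 0
        while take!(channel)
            n += 1
        end
        n
    end

    # Producers: one task per chain, one `true` per sampling step.
    @sync for _ in 1:nchains
        Threads.@spawn for _ in 1:nsteps
            put!(channel, true)
        end
    end

    # Only after every producer has finished is it safe to send `false`;
    # sending it earlier would stop the consumer while chains are still running.
    put!(channel, false)
    return fetch(consumer)
end

count_updates(4, 25)  # the consumer sees every update from all four chains
```

This is also why `finish_progress!(::ChannelProgress)` is a no-op in `src/logging.jl`: an individual chain must not close the shared loop on behalf of the others.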
