diff --git a/Project.toml b/Project.toml index 66ada2dd..7bf476a0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.7.0" +version = "5.7.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/logging.jl b/src/logging.jl index 4a8b6822..79db88c1 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -6,6 +6,8 @@ internally take for single-chain sampling. """ abstract type AbstractProgressKwarg end +DEFAULT_N_UPDATES = 200 + """ CreateNewProgressBar @@ -27,7 +29,7 @@ end function finish_progress!(p::CreateNewProgressBar) ProgressLogging.@logprogress p.name "done" _id = p.uuid end -get_n_updates(::CreateNewProgressBar) = 200 +get_n_updates(::CreateNewProgressBar) = DEFAULT_N_UPDATES """ NoLogging @@ -38,7 +40,7 @@ struct NoLogging <: AbstractProgressKwarg end init_progress!(::NoLogging) = nothing update_progress!(::NoLogging, ::Any) = nothing finish_progress!(::NoLogging) = nothing -get_n_updates(::NoLogging) = 200 +get_n_updates(::NoLogging) = DEFAULT_N_UPDATES """ ExistingProgressBar @@ -72,7 +74,7 @@ end function finish_progress!(p::ExistingProgressBar) ProgressLogging.@logprogress p.name "done" _id = p.uuid end -get_n_updates(::ExistingProgressBar) = 200 +get_n_updates(::ExistingProgressBar) = DEFAULT_N_UPDATES """ ChannelProgress diff --git a/src/sample.jl b/src/sample.jl index d09a30ab..52f1e35c 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -452,13 +452,12 @@ function mcmcsample( progress_channel = Channel{Bool}(nchains) overall_progress_bar = CreateNewProgressBar(progressname) # If we have many chains and many samples, we don't want to force - # each chain to report back to the main thread for each sample, as - # this would cause serious performance issues due to lock conflicts. - # In the overall progress bar we only expect 200 updates (i.e., one - # update per 0.5%). To avoid possible throttling issues we ask for - # twice the amount needed per chain, which doesn't cause a real - # performance hit. - updates_per_chain = max(1, 400 ÷ nchains) + # each chain to report back to the main thread for each sample, as this would + # cause serious performance issues due to lock conflicts. In the overall + # progress bar we only expect N updates (by default N = 200, i.e., one update + # per 0.5%). To avoid possible throttling issues we ask for twice + # the amount needed per chain, which doesn't cause a real performance hit. + updates_per_chain = max(1, (2 * get_n_updates(overall_progress_bar)) ÷ nchains) init_progress!(overall_progress_bar) end if progress == :perchain @@ -483,7 +482,7 @@ function mcmcsample( Ntotal = nchains * updates_per_chain # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = Ntotal / 200 + threshold = Ntotal / get_n_updates(overall_progress_bar) next_update = threshold itotal = 0 @@ -633,7 +632,7 @@ function mcmcsample( overall_progress_bar = CreateNewProgressBar(progressname) init_progress!(overall_progress_bar) # See MCMCThreads method for the rationale behind updates_per_chain. - updates_per_chain = max(1, 400 ÷ nchains) + updates_per_chain = max(1, (2 * get_n_updates(overall_progress_bar)) ÷ nchains) child_progresses = [ ChannelProgress(progress_channel, updates_per_chain) for _ in 1:nchains ] @@ -646,9 +645,8 @@ function mcmcsample( # This task updates the progress bar Distributed.@async begin # Determine threshold values for progress logging - # (one update per 0.5% of progress) Ntotal = nchains * updates_per_chain - threshold = Ntotal / 200 + threshold = Ntotal / get_n_updates(overall_progress_bar) next_update = threshold itotal = 0