diff --git a/Project.toml b/Project.toml index 77de8700..11f1d024 100644 --- a/Project.toml +++ b/Project.toml @@ -7,27 +7,20 @@ version = "5.6.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" -ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19, 0.4" -ConsoleProgressMonitor = "0.1" FillArrays = "1" LogDensityProblems = "2" -LoggingExtras = "0.4, 0.5, 1" -ProgressLogging = "0.1" +ProgressMeter = "1.10.4" StatsBase = "0.32, 0.33, 0.34" -TerminalLoggers = "0.1" Transducers = "0.4.30" julia = "1.6" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index b7c35fb6..cf8aadba 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,17 +1,12 @@ module AbstractMCMC using BangBang: BangBang -using ConsoleProgressMonitor: ConsoleProgressMonitor using LogDensityProblems: LogDensityProblems -using LoggingExtras: LoggingExtras -using ProgressLogging: ProgressLogging using StatsBase: StatsBase -using TerminalLoggers: TerminalLoggers using Transducers: Transducers using FillArrays: FillArrays using Distributed: Distributed -using Logging: Logging using Random: Random # Reexport sample @@ -113,7 +108,6 @@ function setparams!!(model::AbstractModel, state, params) end include("samplingstats.jl") -include("logging.jl") include("interface.jl") include("sample.jl") include("stepper.jl") diff --git a/src/logging.jl b/src/logging.jl deleted 
file mode 100644 index 04c41187..00000000 --- a/src/logging.jl +++ /dev/null @@ -1,48 +0,0 @@ -# avoid creating a progress bar with @withprogress if progress logging is disabled -# and add a custom progress logger if the current logger does not seem to be able to handle -# progress logs -macro ifwithprogresslogger(progress, exprs...) - return esc( - quote - if $progress - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do - $ProgressLogging.@withprogress $(exprs...) - end - end - else - $(exprs[end]) - end - end, - ) -end - -# improved checks? -function hasprogresslevel(logger) - return Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel -end - -# filter better, e.g., according to group? -function with_progresslogger(f, _module, logger) - logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger()) do log - log._module === _module && log.level == ProgressLogging.ProgressLevel - end - logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log - log._module !== _module || log.level != ProgressLogging.ProgressLevel - end - - return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) -end - -function progresslogger() - # detect if code is running under IJulia since TerminalLogger does not work with IJulia - # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || - (isdefined(Main, :IJulia) && Main.IJulia.inited) - return ConsoleProgressMonitor.ProgressLogger() - else - return TerminalLoggers.TerminalLogger() - end -end diff --git a/src/sample.jl b/src/sample.jl index 01a2006a..2b1dfde2 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,3 +1,5 @@ +using ProgressMeter: ProgressMeter + # Default implementations of `sample`. 
const PROGRESS = Ref(true) @@ -27,7 +29,7 @@ end """ sample( - rng::Random.AbatractRNG=Random.default_rng(), + rng::Random.AbstractRNG=Random.default_rng(), model::AbstractModel, sampler::AbstractSampler, N_or_isdone; @@ -119,6 +121,8 @@ function mcmcsample( thinning=1, chain_type::Type=Any, initial_state=nothing, + _chain_idx=nothing, # for use when sampling multiple chains + _progress_channel=nothing, # for use when sampling multiple chains kwargs..., ) # Check the number of requested samples. @@ -141,77 +145,54 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - if progress - threshold = Ntotal ÷ 200 - next_update = threshold - end + # Set up progress logging object + progress_obj = ProgressMeter.Progress( + Ntotal; desc=progressname, dt=0.01, enabled=progress + ) - # Obtain the initial sample and state. - sample, state = if num_warmup > 0 - if initial_state === nothing - step_warmup(rng, model, sampler; kwargs...) - else - step_warmup(rng, model, sampler, initial_state; kwargs...) - end + # Obtain the initial sample and state. + sample, state = if num_warmup > 0 + if initial_state === nothing + step_warmup(rng, model, sampler; kwargs...) else - if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end + step_warmup(rng, model, sampler, initial_state; kwargs...) end - - # Update the progress bar. - itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + else + if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, initial_state; kwargs...) end + end - # Discard initial samples. - for j in 1:discard_initial - # Obtain the next sample and state. 
- sample, state = if j ≤ num_warmup - step_warmup(rng, model, sampler, state; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end - - # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold - end + # Discard initial samples. + for j in 1:discard_initial + # Obtain the next sample and state. + sample, state = if j ≤ num_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) end - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) + # Update the progress bar (and channel if being called from a multi-chain method) + ProgressMeter.next!(progress_obj) + end - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) - samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) - - # Step through the sampler. - for i in 2:N - # Discard thinned samples. - for _ in 1:(thinning - 1) - # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) - # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold - end - end + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) + samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) + + # Update the progress bar (and channel if being called from a multi-chain method) + ProgressMeter.next!(progress_obj) + _progress_channel === nothing || put!(_progress_channel, (_chain_idx, true)) + # Step through the sampler. + for i in 2:N + # Discard thinned samples. 
+ for _ in 1:(thinning - 1) # Obtain the next sample and state. sample, state = if i ≤ keep_from_warmup step_warmup(rng, model, sampler, state; kwargs...) @@ -219,19 +200,26 @@ function mcmcsample( step(rng, model, sampler, state; kwargs...) end - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + # Update the progress bar + ProgressMeter.next!(progress_obj) + end - # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold - end + # Obtain the next sample and state. + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) end + + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + + # Update the progress bar (and channel if being called from a multi-chain method) + ProgressMeter.next!(progress_obj) + _progress_channel === nothing || put!(_progress_channel, (_chain_idx, true)) end # Get the sample stop time. 
@@ -239,6 +227,10 @@ function mcmcsample( duration = stop - start stats = SamplingStats(start, stop, duration) + # Stop the progress bar + ProgressMeter.finish!(progress_obj; keep=true) + _progress_channel === nothing || put!(_progress_channel, (_chain_idx, false)) + return bundle_samples( samples, model, @@ -278,73 +270,79 @@ function mcmcsample( discard_from_warmup = min(num_warmup, discard_initial) keep_from_warmup = num_warmup - discard_from_warmup + # Set up progress logging object (if needed) + progress_obj = ProgressMeter.ProgressUnknown(; desc=progressname, enabled=progress) + # Start the timer start = time() local state - @ifwithprogresslogger progress name = progressname begin - # Obtain the initial sample and state. - sample, state = if num_warmup > 0 - if initial_state === nothing - step_warmup(rng, model, sampler; kwargs...) - else - step_warmup(rng, model, sampler, initial_state; kwargs...) - end + # Obtain the initial sample and state. + sample, state = if num_warmup > 0 + if initial_state === nothing + step_warmup(rng, model, sampler; kwargs...) else - if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end + step_warmup(rng, model, sampler, initial_state; kwargs...) end + else + if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, initial_state; kwargs...) + end + end - # Discard initial samples. - for j in 1:discard_initial - # Obtain the next sample and state. - sample, state = if j ≤ num_warmup - step_warmup(rng, model, sampler, state; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + # Discard initial samples. + for j in 1:discard_initial + # Obtain the next sample and state. + sample, state = if j ≤ num_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) end + end - # Run callback. 
- callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, 1, model, sampler; kwargs...) - - # Step through the sampler until stopping. - i = 2 - while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) - # Discard thinned samples. - for _ in 1:(thinning - 1) - # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end - end + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, 1, model, sampler; kwargs...) + + # Update the progress bar. + ProgressMeter.next!(progress_obj) + # Step through the sampler until stopping. + i = 2 + while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) + # Discard thinned samples. + for _ in 1:(thinning - 1) # Obtain the next sample and state. sample, state = if i ≤ keep_from_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) end + end + + # Obtain the next sample and state. + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end + + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) + # Save the sample. + samples = save!!(samples, sample, i, model, sampler; kwargs...) - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) + # Increment iteration counter. + i += 1 - # Increment iteration counter. 
- i += 1 - end + # Update the progress bar. + ProgressMeter.next!(progress_obj) end # Get the sample stop time. @@ -352,6 +350,9 @@ function mcmcsample( duration = stop - start stats = SamplingStats(start, stop, duration) + # Stop the progress bar + ProgressMeter.finish!(progress_obj) + # Wrap the samples up. return bundle_samples( samples, @@ -411,78 +412,93 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name = progressname begin - # Create a channel for progress logging. - if progress - channel = Channel{Bool}(length(interval)) + # Create overall progress logging object (tracks number of chains completed) + overall_progress_obj = ProgressMeter.Progress( + nchains; desc=progressname, dt=0.0, enabled=progress + ) + # ProgressMeter doesn't start printing until the second iteration or so. This + # forces it to start printing an empty progress bar immediately. + # https://github.com/timholy/ProgressMeter.jl/issues/288 + ProgressMeter.update!(overall_progress_obj, 0; force=true) + # Create per-chain progress logging objects + progress_objs = [ + ProgressMeter.Progress( + N; desc="Chain $i/$nchains", dt=0.01, enabled=progress, offset=i + ) for i in 1:nchains + ] + for obj in progress_objs + ProgressMeter.update!(obj, 0; force=true) + end + # Create a channel to synchronise progress updates. + channel = Distributed.RemoteChannel(() -> Channel{Tuple{Int,Bool}}(), 1) + + Distributed.@sync begin + Distributed.@async while true + i, res = take!(channel) + # i == 0 means the overall progress bar; i > 0 means the + # progress bar for chain i. + prog_obj = if i == 0 + overall_progress_obj + else + progress_objs[i] + end + if res # true = a chain / sample finished + ProgressMeter.next!(prog_obj) + else # false = all chains / samples finished (or one failed) + ProgressMeter.finish!(prog_obj) + i == 0 && break + end end - Distributed.@sync begin - if progress - # Update the progress bar. 
- Distributed.@async begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold - end + Distributed.@async begin + try + Distributed.@sync for (i, _rng, _model, _sampler) in + zip(interval, rngs, models, samplers) + if i <= n + chainidx_hi = i * (m + 1) + nchains_chunk = m + 1 + else + chainidx_hi = i * m + n # n * (m + 1) + (i - n) * m + nchains_chunk = m end - end - end + chainidx_lo = chainidx_hi - nchains_chunk + 1 + chainidxs = chainidx_lo:chainidx_hi + + Threads.@spawn for chainidx in chainidxs + # Seed the chunk-specific random number generator with the pre-made seed. + Random.seed!(_rng, seeds[chainidx]) + + # Sample a chain and save it to the vector. + chains[chainidx] = StatsBase.sample( + _rng, + _model, + _sampler, + N; + progress=false, + progressname="Chain $chainidx/$nchains", + # use these to allow each chain to update progress bar + _chain_idx=chainidx, + _progress_channel=channel, + initial_params=if initial_params === nothing + nothing + else + initial_params[chainidx] + end, + initial_state=if initial_state === nothing + nothing + else + initial_state[chainidx] + end, + kwargs..., + ) - Distributed.@async begin - try - Distributed.@sync for (i, _rng, _model, _sampler) in - zip(interval, rngs, models, samplers) - if i <= n - chainidx_hi = i * (m + 1) - nchains_chunk = m + 1 - else - chainidx_hi = i * m + n # n * (m + 1) + (i - n) * m - nchains_chunk = m - end - chainidx_lo = chainidx_hi - nchains_chunk + 1 - chainidxs = chainidx_lo:chainidx_hi - - Threads.@spawn for chainidx in chainidxs - # Seed the chunk-specific random number generator with the pre-made seed. 
- Random.seed!(_rng, seeds[chainidx]) - - # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample( - _rng, - _model, - _sampler, - N; - progress=false, - initial_params=if initial_params === nothing - nothing - else - initial_params[chainidx] - end, - initial_state=if initial_state === nothing - nothing - else - initial_state[chainidx] - end, - kwargs..., - ) - - # Update the progress bar. - progress && put!(channel, true) - end + # Update the overall progress bar. + put!(channel, (0, true)) end - finally - # Stop updating the progress bar. - progress && put!(channel, false) end + finally + # Stop updating the overall progress bar. + put!(channel, (0, false)) end end end @@ -530,63 +546,77 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name = progressname begin - # Create a channel for progress logging. - if progress - channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) - end - Distributed.@sync begin - if progress - # Update the progress bar. - Distributed.@async begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold - end - end - end + # Create overall progress logging object (tracks number of chains completed) + overall_progress_obj = ProgressMeter.Progress( + nchains; desc=progressname, dt=0.0, enabled=progress + ) + # ProgressMeter doesn't start printing until the second iteration or so. This + # forces it to start printing an empty progress bar immediately. 
+ # https://github.com/timholy/ProgressMeter.jl/issues/288 + ProgressMeter.update!(overall_progress_obj, 0; force=true) + # Create per-chain progress logging objects + progress_objs = [ + ProgressMeter.Progress( + N; desc="Chain $i/$nchains", dt=0.01, enabled=progress, offset=i + ) for i in 1:nchains + ] + for obj in progress_objs + ProgressMeter.update!(obj, 0; force=true) + end + # Create a channel to synchronise progress updates. + channel = Distributed.RemoteChannel(() -> Channel{Tuple{Int,Bool}}(), 1) + + Distributed.@sync begin + Distributed.@async while true + i, res = take!(channel) + # i == 0 means the overall progress bar; i > 0 means the + # progress bar for chain i. + prog_obj = if i == 0 + overall_progress_obj + else + progress_objs[i] end + if res # true = a chain / sample finished + ProgressMeter.next!(prog_obj) + else # false = all chains / samples finished (or one failed) + ProgressMeter.finish!(prog_obj; keep=true) + i == 0 && break + end + end - Distributed.@async begin - try - function sample_chain(seed, initial_params, initial_state) - # Seed a new random number generator with the pre-made seed. - Random.seed!(rng, seed) - - # Sample a chain. - chain = StatsBase.sample( - rng, - model, - sampler, - N; - progress=false, - initial_params=initial_params, - initial_state=initial_state, - kwargs..., - ) + Distributed.@async begin + try + function sample_chain(i, seed, initial_params, initial_state) + # Seed a new random number generator with the pre-made seed. + Random.seed!(rng, seed) + + # Sample a chain. + chain = StatsBase.sample( + rng, + model, + sampler, + N; + progress=false, + initial_params=initial_params, + initial_state=initial_state, + _chain_idx=i, + _progress_channel=channel, + kwargs..., + ) - # Update the progress bar. - progress && put!(channel, true) + # Tell the consumer task this chain is done (display is a no-op when progress=false). + put!(channel, (0, true)) - # Return the new chain.
- return chain - end - chains = Distributed.pmap( - sample_chain, pool, seeds, _initial_params, _initial_state - ) - finally - # Stop updating the progress bar. - progress && put!(channel, false) + # Return the new chain. + return chain end + chains = Distributed.pmap( + sample_chain, pool, 1:nchains, seeds, _initial_params, _initial_state + ) + finally + # Always send the stop sentinel; otherwise the consumer task never exits and @sync hangs. + put!(channel, (0, false)) end end end @@ -602,7 +632,7 @@ function mcmcsample( ::MCMCSerial, N::Integer, nchains::Integer; - progressname="Sampling", + progress=PROGRESS[], initial_params=nothing, initial_state=nothing, kwargs..., @@ -635,7 +665,8 @@ function mcmcsample( model, sampler, N; - progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), + progress=progress, + progressname="Chain $i/$nchains", initial_params=initial_params, initial_state=initial_state, kwargs...,