Skip to content

Commit 1539c2a

Browse files
committed
Add setmaxchainsprogress!
1 parent 3a5c243 commit 1539c2a

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

src/sample.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Default implementations of `sample`.
22
const PROGRESS = Ref(true)
3+
const MAX_CHAINS_PROGRESS = Ref(10)
34

45
_pluralise(n; singular="", plural="s") = n == 1 ? singular : plural
56

@@ -17,6 +18,25 @@ function setprogress!(progress::Bool; silent::Bool=false)
1718
return progress
1819
end
1920

21+
"""
22+
setmaxchainsprogress!(max_chains::Int, silent::Bool=false)
23+
24+
Set the maximum number of chains to display progress bars for when sampling
25+
multiple chains at once (if progress logging is enabled). Above this limit, no
26+
progress bars are displayed for individual chains; instead, a single progress
27+
bar is displayed for the entire sampling process.
28+
"""
29+
function setmaxchainsprogress!(max_chains::Int, silent::Bool=false)
30+
if max_chains < 0
31+
throw(ArgumentError("maximum number of chains must be non-negative"))
32+
end
33+
if !silent
34+
@info "AbstractMCMC: maximum number of per-chain progress bars set to $max_chains"
35+
end
36+
MAX_CHAINS_PROGRESS[] = max_chains
37+
return nothing
38+
end
39+
2040
function StatsBase.sample(
2141
model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
2242
)
@@ -408,7 +428,7 @@ function mcmcsample(
408428

409429
# Determine default progress bar style.
410430
if progress == true
411-
progress = nchains > 10 ? :overall : :perchain
431+
progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
412432
elseif progress == false
413433
progress = :none
414434
end

0 commit comments

Comments
 (0)