Skip to content

Commit d6fe92d

Browse files
committed
Check if the number of chains is greater than the number of samples per chain
1 parent c92a866 commit d6fe92d

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/sample.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ function mcmcsample(
197197
@warn "Only a single thread available: MCMC chains are not sampled in parallel"
198198
end
199199

200+
# Check if the number of chains is larger than the number of samples
201+
if nchains > N
202+
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
203+
end
204+
200205
# Copy the random number generator, model, and sample for each thread
201206
rngs = [deepcopy(rng) for _ in 1:Threads.nthreads()]
202207
models = [deepcopy(model) for _ in 1:Threads.nthreads()]
@@ -269,6 +274,11 @@ function mcmcsample(
269274
@warn "Only a single process available: MCMC chains are not sampled in parallel"
270275
end
271276

277+
# Check if the number of chains is larger than the number of samples
278+
if nchains > N
279+
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
280+
end
281+
272282
# Create a seed for each chain using the provided random number generator.
273283
seeds = rand(rng, UInt, nchains)
274284

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ include("interface.jl")
139139

140140
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
141141

142+
# Unexpected order of arguments.
143+
str = "Number of chains (10) is greater than number of samples per chain (5)"
144+
@test_logs (:warn, str) sample(MyModel(), MySampler(), MCMCThreads(), 5, 10;
145+
chain_type = MyChain)
146+
142147
# Suppress output.
143148
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
144149
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
@@ -189,6 +194,11 @@ include("interface.jl")
189194

190195
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
191196

197+
# Unexpected order of arguments.
198+
str = "Number of chains (10) is greater than number of samples per chain (5)"
199+
@test_logs (:warn, str) sample(MyModel(), MySampler(), MCMCDistributed(), 5, 10;
200+
chain_type = MyChain)
201+
192202
# Suppress output.
193203
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
194204
sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100;

0 commit comments

Comments
 (0)