Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.8.0"
version = "5.8.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
32 changes: 32 additions & 0 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ function StatsBase.sample(
return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...)
end

# Utility function to check and warn about common kwargs mistakes
function _check_initial_params_kwarg(kwargs)
if haskey(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
return true
end
return false
end

# Utility function to remove initial_parameters from kwargs after warning
function _filter_initial_params_kwarg(kwargs)
return pairs((; (k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...))
end

# Default implementations of regular and parallel sampling.
function mcmcsample(
rng::Random.AbstractRNG,
Expand All @@ -121,6 +135,9 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params
_check_initial_params_kwarg(kwargs)

# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
discard_initial >= 0 ||
Expand Down Expand Up @@ -405,6 +422,11 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
if _check_initial_params_kwarg(kwargs)
kwargs = _filter_initial_params_kwarg(kwargs)
end

# Check if actually multiple threads are used.
if Threads.nthreads() == 1
@warn "Only a single thread available: MCMC chains are not sampled in parallel"
Expand Down Expand Up @@ -588,6 +610,11 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
if _check_initial_params_kwarg(kwargs)
kwargs = _filter_initial_params_kwarg(kwargs)
end

# Check if actually multiple processes are used.
if Distributed.nworkers() == 1
@warn "Only a single process available: MCMC chains are not sampled in parallel"
Expand Down Expand Up @@ -727,6 +754,11 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
if _check_initial_params_kwarg(kwargs)
kwargs = _filter_initial_params_kwarg(kwargs)
end

# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
Expand Down
39 changes: 39 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@
)
@test chain[1].a == -1.8
@test chain[1].b == 3.2

# test warning for initial_parameters (typo) in single-chain sampling
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(), MySampler(), 3; progress=false, initial_parameters=(b=1.0, a=2.0)
)

# test warning for initial_parameters (typo) in multi-chain sampling
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") match_mode =
:any sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)
end

@testset "IJulia" begin
Expand Down Expand Up @@ -282,6 +299,17 @@
MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain
)

# Test warning for initial_parameters (typo)
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(
Expand Down Expand Up @@ -408,6 +436,17 @@
MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain
)

# Test warning for initial_parameters (typo)
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(
Expand Down
Loading