Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.8.0"
version = "5.8.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
32 changes: 32 additions & 0 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params
if haskey(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
end

# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
discard_initial >= 0 ||
Expand Down Expand Up @@ -405,6 +410,15 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params
if haskey(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
# Remove initial_parameters from kwargs to prevent it from being passed to single-chain sample
kwargs = pairs((;
(k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...
))
end

# Check if actually multiple threads are used.
if Threads.nthreads() == 1
@warn "Only a single thread available: MCMC chains are not sampled in parallel"
Expand Down Expand Up @@ -588,6 +602,15 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params
if haskey(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
# Remove initial_parameters from kwargs to prevent it from being passed to single-chain sample
kwargs = pairs((;
(k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...
))
end

# Check if actually multiple processes are used.
if Distributed.nworkers() == 1
@warn "Only a single process available: MCMC chains are not sampled in parallel"
Expand Down Expand Up @@ -727,6 +750,15 @@ function mcmcsample(
initial_state=nothing,
kwargs...,
)
# Warn if initial_parameters is passed instead of initial_params
if haskey(kwargs, :initial_parameters)
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
# Remove initial_parameters from kwargs to prevent it from being passed to single-chain sample
kwargs = pairs((;
(k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...
))
end

# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
Expand Down
39 changes: 39 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@
)
@test chain[1].a == -1.8
@test chain[1].b == 3.2

# test warning for initial_parameters (typo) in single-chain sampling
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(), MySampler(), 3; progress=false, initial_parameters=(b=1.0, a=2.0)
)

# test warning for initial_parameters (typo) in multi-chain sampling
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") match_mode =
:any sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)
end

@testset "IJulia" begin
Expand Down Expand Up @@ -282,6 +299,17 @@
MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain
)

# Test warning for initial_parameters (typo)
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(
Expand Down Expand Up @@ -408,6 +436,17 @@
MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain
)

# Test warning for initial_parameters (typo)
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
2;
progress=false,
initial_parameters=(b=1.0, a=2.0),
)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(
Expand Down
Loading