diff --git a/Project.toml b/Project.toml index 3580425a..f49b214a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.8.0" +version = "5.8.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 913332d1..13e91b7a 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -105,6 +105,20 @@ function StatsBase.sample( return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...) end +# Utility function to check and warn about common kwargs mistakes +function _check_initial_params_kwarg(kwargs) + if haskey(kwargs, :initial_parameters) + @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." + return true + end + return false +end + +# Utility function to remove initial_parameters from kwargs after warning +function _filter_initial_params_kwarg(kwargs) + return pairs((; (k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...)) +end + # Default implementations of regular and parallel sampling. function mcmcsample( rng::Random.AbstractRNG, @@ -121,6 +135,9 @@ function mcmcsample( initial_state=nothing, kwargs..., ) + # Warn if initial_parameters is passed instead of initial_params + _check_initial_params_kwarg(kwargs) + # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") discard_initial >= 0 || @@ -405,6 +422,11 @@ function mcmcsample( initial_state=nothing, kwargs..., ) + # Warn if initial_parameters is passed instead of initial_params and remove it from kwargs + if _check_initial_params_kwarg(kwargs) + kwargs = _filter_initial_params_kwarg(kwargs) + end + # Check if actually multiple threads are used. if Threads.nthreads() == 1 @warn "Only a single thread available: MCMC chains are not sampled in parallel" @@ -588,6 +610,11 @@ function mcmcsample( initial_state=nothing, kwargs..., ) + # Warn if initial_parameters is passed instead of initial_params and remove it from kwargs + if _check_initial_params_kwarg(kwargs) + kwargs = _filter_initial_params_kwarg(kwargs) + end + # Check if actually multiple processes are used. if Distributed.nworkers() == 1 @warn "Only a single process available: MCMC chains are not sampled in parallel" @@ -727,6 +754,11 @@ function mcmcsample( initial_state=nothing, kwargs..., ) + # Warn if initial_parameters is passed instead of initial_params and remove it from kwargs + if _check_initial_params_kwarg(kwargs) + kwargs = _filter_initial_params_kwarg(kwargs) + end + # Check if the number of chains is larger than the number of samples if nchains > N @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" diff --git a/test/sample.jl b/test/sample.jl index f561f535..954acc7c 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -31,6 +31,23 @@ ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 + + # test warning for initial_parameters (typo) in single-chain sampling + @test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample( + MyModel(), MySampler(), 3; progress=false, initial_parameters=(b=1.0, a=2.0) + ) + + # test warning for initial_parameters (typo) in multi-chain sampling + @test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") match_mode = + :any sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + 2; + progress=false, + initial_parameters=(b=1.0, a=2.0), + ) end @testset "IJulia" begin @@ -282,6 +299,17 @@ MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain ) + # Test warning for initial_parameters (typo) + @test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + 2; + progress=false, + initial_parameters=(b=1.0, a=2.0), + ) + # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do sample( @@ -408,6 +436,17 @@ MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain ) + # Test warning for initial_parameters (typo) + @test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + 2; + progress=false, + initial_parameters=(b=1.0, a=2.0), + ) + # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do sample(