Skip to content

Commit ae9760d

Browse files
Show warning message if initial_parameters is passed instead of initial_params (#176)
* Show warning message if initial_parameters is passed instead of initial_params * run JuliaFormatter * fix CI * remove initial_parameters from kwargs in multi-chain functions * format * remove reduntant comment * minor version bump * add utility function for initial_parameters warning
1 parent fc74dd4 commit ae9760d

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.8.0"
6+
version = "5.8.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ function StatsBase.sample(
105105
return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...)
106106
end
107107

108+
# Utility function to check and warn about common kwargs mistakes
109+
function _check_initial_params_kwarg(kwargs)
110+
if haskey(kwargs, :initial_parameters)
111+
@warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead."
112+
return true
113+
end
114+
return false
115+
end
116+
117+
# Utility function to remove initial_parameters from kwargs after warning
118+
function _filter_initial_params_kwarg(kwargs)
119+
return pairs((; (k => v for (k, v) in pairs(kwargs) if k !== :initial_parameters)...))
120+
end
121+
108122
# Default implementations of regular and parallel sampling.
109123
function mcmcsample(
110124
rng::Random.AbstractRNG,
@@ -121,6 +135,9 @@ function mcmcsample(
121135
initial_state=nothing,
122136
kwargs...,
123137
)
138+
# Warn if initial_parameters is passed instead of initial_params
139+
_check_initial_params_kwarg(kwargs)
140+
124141
# Check the number of requested samples.
125142
N > 0 || error("the number of samples must be ≥ 1")
126143
discard_initial >= 0 ||
@@ -405,6 +422,11 @@ function mcmcsample(
405422
initial_state=nothing,
406423
kwargs...,
407424
)
425+
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
426+
if _check_initial_params_kwarg(kwargs)
427+
kwargs = _filter_initial_params_kwarg(kwargs)
428+
end
429+
408430
# Check if actually multiple threads are used.
409431
if Threads.nthreads() == 1
410432
@warn "Only a single thread available: MCMC chains are not sampled in parallel"
@@ -588,6 +610,11 @@ function mcmcsample(
588610
initial_state=nothing,
589611
kwargs...,
590612
)
613+
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
614+
if _check_initial_params_kwarg(kwargs)
615+
kwargs = _filter_initial_params_kwarg(kwargs)
616+
end
617+
591618
# Check if actually multiple processes are used.
592619
if Distributed.nworkers() == 1
593620
@warn "Only a single process available: MCMC chains are not sampled in parallel"
@@ -727,6 +754,11 @@ function mcmcsample(
727754
initial_state=nothing,
728755
kwargs...,
729756
)
757+
# Warn if initial_parameters is passed instead of initial_params and remove it from kwargs
758+
if _check_initial_params_kwarg(kwargs)
759+
kwargs = _filter_initial_params_kwarg(kwargs)
760+
end
761+
730762
# Check if the number of chains is larger than the number of samples
731763
if nchains > N
732764
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"

test/sample.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,23 @@
3131
)
3232
@test chain[1].a == -1.8
3333
@test chain[1].b == 3.2
34+
35+
# test warning for initial_parameters (typo) in single-chain sampling
36+
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
37+
MyModel(), MySampler(), 3; progress=false, initial_parameters=(b=1.0, a=2.0)
38+
)
39+
40+
# test warning for initial_parameters (typo) in multi-chain sampling
41+
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") match_mode =
42+
:any sample(
43+
MyModel(),
44+
MySampler(),
45+
MCMCThreads(),
46+
3,
47+
2;
48+
progress=false,
49+
initial_parameters=(b=1.0, a=2.0),
50+
)
3451
end
3552

3653
@testset "IJulia" begin
@@ -282,6 +299,17 @@
282299
MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain
283300
)
284301

302+
# Test warning for initial_parameters (typo)
303+
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
304+
MyModel(),
305+
MySampler(),
306+
MCMCDistributed(),
307+
3,
308+
2;
309+
progress=false,
310+
initial_parameters=(b=1.0, a=2.0),
311+
)
312+
285313
# Suppress output.
286314
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
287315
sample(
@@ -408,6 +436,17 @@
408436
MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain
409437
)
410438

439+
# Test warning for initial_parameters (typo)
440+
@test_logs (:warn, r"initial_parameters.*not recognised.*initial_params") sample(
441+
MyModel(),
442+
MySampler(),
443+
MCMCSerial(),
444+
3,
445+
2;
446+
progress=false,
447+
initial_parameters=(b=1.0, a=2.0),
448+
)
449+
411450
# Suppress output.
412451
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
413452
sample(

0 commit comments

Comments
 (0)