Skip to content

Commit 0795c75

Browse files
committed
Add more tests
1 parent c3ba043 commit 0795c75

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/sampler.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ function _convert_initial_params(d::AbstractDict{<:VarName})
6868
return InitFromParams(d)
6969
end
7070
function _convert_initial_params(::AbstractVector)
71-
return error(
72-
"`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code.",
73-
)
71+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
72+
throw(ArgumentError(errmsg))
7473
end
7574

7675
function AbstractMCMC.sample(

test/sampler.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,14 @@
138138
end
139139
end
140140

141+
# check that Vector no longer works
142+
@test_throws ArgumentError sample(
143+
model, sampler, 1; initial_params=[4, -1], progress=false
144+
)
145+
@test_throws ArgumentError sample(
146+
model, sampler, 1; initial_params=[missing, -1], progress=false
147+
)
148+
141149
# model with two variables: initialization s = 4, m = -1
142150
@model function twovars()
143151
s ~ InverseGamma(2, 3)
@@ -181,7 +189,8 @@
181189
Dict(@varname(s) => missing, @varname(m) => -1),
182190
InitFromParams((; m=-1)),
183191
InitFromParams(Dict(@varname(m) => -1)),
184-
(; m=-1)Dict(@varname(m) => -1),
192+
(; m=-1),
193+
Dict(@varname(m) => -1),
185194
)
186195
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
187196
@test !ismissing(chain[1].metadata.s.vals[1])

0 commit comments

Comments
 (0)