Skip to content

Commit bf5aeb4

Browse files
committed
improve error message for initial_params
1 parent 3d18cfc commit bf5aeb4

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

src/sampler.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,22 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
157157
initialsampler(spl::Sampler) = SampleFromPrior()
158158

159159
function set_values!!(
160-
varinfo::AbstractVarInfo,
161-
initial_params::AbstractVector{<:Union{Real,Missing}},
162-
spl::AbstractSampler,
163-
)
160+
varinfo::AbstractVarInfo, initial_params::AbstractVector{T}, spl::AbstractSampler
161+
) where {T}
162+
if T === Any
163+
throw(
164+
ArgumentError(
165+
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
166+
"If `initial_params` is a vector of vectors, please flatten it first using `vcat`.",
167+
),
168+
)
169+
end
170+
164171
flattened_param_vals = varinfo[spl]
165172
length(flattened_param_vals) == length(initial_params) || throw(
166173
DimensionMismatch(
167-
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))",
174+
"Provided initial value size ($(length(initial_params))) doesn't match " *
175+
"the model size ($(length(flattened_param_vals))).",
168176
),
169177
)
170178

@@ -183,6 +191,23 @@ end
183191
function set_values!!(
184192
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
185193
)
194+
vars_in_varinfo = keys(varinfo)
195+
for v in keys(initial_params)
196+
if !(v in vars_in_varinfo)
197+
for vv in vars_in_varinfo
198+
if subsumes(VarName{v}(), vv)
199+
throw(
200+
ArgumentError(
201+
"Variable $v not found in model, but it subsumes a variable ($vv) in the model. " *
202+
"Please use AbstractVector for initial_params instead of NamedTuple.",
203+
),
204+
)
205+
end
206+
end
207+
208+
throw(ArgumentError("Variable $v not found in the model."))
209+
end
210+
end
186211
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
187212
return update_values!!(
188213
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))

test/sampler.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,30 @@
178178
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
179179
end
180180
end
181+
182+
@testset "error handling" begin
183+
# https://github.com/TuringLang/Turing.jl/issues/2452
184+
@model function constrained_uniform(n)
185+
Z ~ Uniform(10, 20)
186+
X = Vector{Float64}(undef, n)
187+
for i in 1:n
188+
X[i] ~ Uniform(0, Z)
189+
end
190+
end
191+
192+
n = 2
193+
initial_z = 15
194+
initial_x = [0.2, 0.5]
195+
model = constrained_uniform(n)
196+
vi = VarInfo(model)
197+
198+
@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
199+
vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model
200+
)
201+
202+
@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
203+
vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model
204+
)
205+
end
181206
end
182207
end

0 commit comments

Comments
 (0)