Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.0"
version = "0.33.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
32 changes: 31 additions & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()

function set_values!!(
varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler
)
throw(
ArgumentError(
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
"If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
),
)
end

function set_values!!(
varinfo::AbstractVarInfo,
initial_params::AbstractVector{<:Union{Real,Missing}},
Expand All @@ -164,7 +175,8 @@ function set_values!!(
flattened_param_vals = varinfo[spl]
length(flattened_param_vals) == length(initial_params) || throw(
DimensionMismatch(
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))",
"Provided initial value size ($(length(initial_params))) doesn't match " *
"the model size ($(length(flattened_param_vals))).",
),
)

Expand All @@ -183,6 +195,24 @@ end
function set_values!!(
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
)
vars_in_varinfo = keys(varinfo)
for v in keys(initial_params)
vn = VarName{v}()
if !(vn in vars_in_varinfo)
for vv in vars_in_varinfo
if subsumes(vn, vv)
throw(
ArgumentError(
"The current model contains sub-variables of $v, such as ($vv). " *
"Using NamedTuple for initial_params is not supported in such a case. " *
"Please use AbstractVector for initial_params instead of NamedTuple.",
),
)
end
end
throw(ArgumentError("Variable $v not found in the model."))
end
end
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
return update_values!!(
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
Expand Down
25 changes: 25 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,30 @@
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
end
end

@testset "error handling" begin
# https://github.com/TuringLang/Turing.jl/issues/2452
@model function constrained_uniform(n)
Z ~ Uniform(10, 20)
X = Vector{Float64}(undef, n)
for i in 1:n
X[i] ~ Uniform(0, Z)
end
end

n = 2
initial_z = 15
initial_x = [0.2, 0.5]
model = constrained_uniform(n)
vi = VarInfo(model)

@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model
)

@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model
)
end
end
end
Loading