Skip to content

Commit 97df07f

Browse files
committed
Initial attempt at hasvalue(vals, vn, dist)
1 parent d0868a4 commit 97df07f

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

src/sampler.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,6 @@ Default type of the chain of posterior samples from `sampler`.
8888
"""
8989
default_chain_type(sampler::Sampler) = Any
9090

91-
"""
92-
init_strategy(sampler)
93-
94-
Define the initialisation strategy used for generating initial values when
95-
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
96-
"""
97-
init_strategy(::Sampler) = PriorInit()
98-
9991
"""
10092
initialstep(rng, model, sampler, varinfo; kwargs...)
10193

src/utils.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ end
845845

846846
# For `dictlike` we need to check wether `vn` is "immediately" present, or
847847
# if some ancestor of `vn` is present in `dictlike`.
848-
function hasvalue(vals::AbstractDict, vn::VarName)
848+
function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName)
849849
# First we check if `vn` is present as is.
850850
haskey(vals, vn) && return true
851851

@@ -867,6 +867,39 @@ function hasvalue(vals::AbstractDict, vn::VarName)
867867

868868
return canview(child, value)
869869
end
870+
# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
871+
function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
872+
@warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
873+
return hasvalue(vals, vn)
874+
end
875+
hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn)
876+
function hasvalue(
877+
vals::AbstractDict{<:VarName},
878+
vn::VarName{sym},
879+
dist::Union{MultivariateDistribution,MatrixDistribution},
880+
) where {sym}
881+
# If `vn` is present as-is, then we are good
882+
hasvalue(vals, vn) && return true
883+
# If not, then we need to check inside `vals` to see if a subset of
884+
# `vals` is enough to reconstruct `vn`. For example, if `vals` contains
885+
# `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
886+
# can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
887+
# can't.
888+
# To do this, we get the size of the distribution and iterate over all
889+
# possible indices. If every index can be found in `subsumed_keys`, then we
890+
# can return true.
891+
sz = size(dist)
892+
for idx in Iterators.product(map(Base.OneTo, sz)...)
893+
new_optic = if getoptic(vn) === identity
894+
Accessors.IndexLens(idx)
895+
else
896+
Accessors.IndexLens(idx) getoptic(vn)
897+
end
898+
new_vn = VarName{sym}(new_optic)
899+
hasvalue(vals, new_vn) || return false
900+
end
901+
return true
902+
end
870903

871904
"""
872905
nested_getindex(values::AbstractDict, vn::VarName)

0 commit comments

Comments
 (0)