Problem
Now that we properly support usage of different sizes in the underlying storage of the varinfo after linking, the current usage of NamedTuple for both the "ground truth" in TestUtils, e.g.
DynamicPPL.jl/src/test_utils.jl
Lines 33 to 37 in 549d9b1
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal, kwargs...)
    for vn in vns
        @test isequal(vi[vn], get(vals, vn); kwargs...)
    end
end
and in SimpleVarInfo, makes less sense than it did before.
To see why, let's consider the following example:
julia> using DynamicPPL, Distributions
julia> @model function demo()
           x = Vector{Float64}(undef, 5)
           x[1] ~ Normal()
           x[2:3] ~ Dirichlet([1.0, 2.0])
           return (x=x,)
       end
demo (generic function with 2 methods)
julia> model = demo();
julia> nt = model()
(x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],)
julia> # Construct `SimpleVarInfo` from `nt`.
vi = SimpleVarInfo(nt)
SimpleVarInfo((x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],), 0.0)
julia> vn = @varname(x[2:3])
x[2:3]
julia> # (✓) Everything works nicely
vi[vn]
2-element Vector{Float64}:
0.6662187241949805
0.3337812758050194
julia> # Now we link it!
vi_linked = DynamicPPL.link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 1 elements to 2 destinations
Stacktrace:
[1] throw_setindex_mismatch(X::Vector{Float64}, I::Tuple{Int64})
@ Base ./indices.jl:191
[2] setindex_shape_check
@ ./indices.jl:245 [inlined]
[3] setindex!
@ ./array.jl:994 [inlined]
[4] _setindex!
@ ~/.julia/packages/BangBang/FUkah/src/base.jl:480 [inlined]
[5] may
@ ~/.julia/packages/BangBang/FUkah/src/core.jl:9 [inlined]
[6] setindex!!
@ ~/.julia/packages/BangBang/FUkah/src/base.jl:478 [inlined]
[7] set(obj::Vector{Float64}, lens::BangBang.SetfieldImpl.Lens!!{Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, value::Vector{Float64})
@ BangBang.SetfieldImpl ~/.julia/packages/BangBang/FUkah/src/setfield.jl:34
[8] set
@ ~/.julia/packages/Setfield/PdKfV/src/lens.jl:188 [inlined]
[9] set
@ ~/.julia/packages/BangBang/FUkah/src/setfield.jl:17 [inlined]
[10] set!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/utils.jl:354 [inlined]
[11] macro expansion
@ ~/.julia/packages/Setfield/PdKfV/src/sugar.jl:197 [inlined]
[12] setindex!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/simple_varinfo.jl:339 [inlined]
[13] tilde_assume(#unused#::DynamicPPL.DynamicTransformationContext{false}, right::Dirichlet{Float64, Vector{Float64}, Float64}, vn::VarName{:x, Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation})
@ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:19
[14] tilde_assume!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/context_implementations.jl:117 [inlined]
[15] demo(__model__::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext}, __varinfo__::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, __context__::DynamicPPL.DynamicTransformationContext{false})
@ Main ./REPL[47]:4
[16] _evaluate!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:963 [inlined]
[17] evaluate_threadunsafe!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:936 [inlined]
[18] evaluate!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:889 [inlined]
[19] link!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:86 [inlined]
[20] link!!
@ /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:384 [inlined]
[21] link!!(vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, model::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext})
@ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:378
[22] top-level scope
@ REPL[53]:2
The issue here can really just be boiled down to the fact that we're trying to use the varname
julia> vn
x[2:3]
to index a NamedTuple, even though after the transformation the corresponding value is represented by a 1-length vector rather than a 2-length vector.
In contrast, SimpleVarInfo{<:AbstractDict} will work just fine because here each varname gets its own entry:
julia> # Construct `SimpleVarInfo` using a dict now.
vi = SimpleVarInfo(rand(OrderedDict, model))
SimpleVarInfo(OrderedDict{Any, Any}(x[1] => -0.12337922752695839, x[2:3] => [0.7836009759179734, 0.21639902408202646]), 0)
julia> # (✓) Everything works nicely
vi[vn]
2-element Vector{Float64}:
0.7836009759179734
0.21639902408202646
julia> # Now we link it!
vi_linked = DynamicPPL.link!!(vi, model);
julia> # (✓) Everything works nicely
vi_linked[vn]
1-element Vector{Float64}:
1.2867758943235161
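The failure mode can be reproduced without DynamicPPL at all. The following is only a minimal sketch in plain Julia: the string keys stand in for VarNames and the "linked" value is a made-up number, both assumptions made purely for illustration.

```julia
# One array shared by every `x[...]` variable, as in the NamedTuple-backed case.
x = zeros(5)
unlinked = [0.6, 0.4]  # a point on the simplex (2 elements)
linked = [0.4]         # hypothetical linked value (1 element)

x[2:3] = unlinked      # fine: 2 values into 2 slots

# Writing the linked value back through the same index cannot work.
err = try
    x[2:3] = linked
    nothing
catch e
    e
end
println(typeof(err))

# With one entry per variable, the stored length is free to change.
# String keys are a stand-in for `VarName`s here.
d = Dict{String,Vector{Float64}}("x[2:3]" => unlinked)
d["x[2:3]"] = linked
println(d["x[2:3]"])
```

The shared array fixes the number of slots behind `x[2:3]`, whereas the dictionary entry is simply replaced.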
"Luckily" it has always been the plan that SimpleVarInfo should be able to use different underlying representations fairly easily, e.g. I've successfully used it with ComponentVector from ComponentArrays.jl many times before. And so we should probably find a more flexible default representation for SimpleVarInfo that can be used in more cases.
Solution
Option 1: Use OrderedDict by default
This one is obviously not great for performance reasons, but it will "just work" in all cases and it's very simple to reason about.
Option 2: Dict-like flattened representation
In an ideal world, the underlying representation of the values in a varinfo would have the following properties:
- It's type-stable, when possible.
- It's contiguous in memory, when possible.
- It's indexable by VarName.
Something like an OrderedDict fails in two regards:
- It's not contiguous in memory.
- Type-stability is not guaranteed, unless we create a dictionary for each eltype or something similar.
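As a rough illustration of the type-stability point (plain Julia only, with string keys standing in for VarNames as an assumption), compare a dictionary holding values of mixed types with a flat vector:

```julia
# A dictionary holding a scalar and a vector needs an abstract value type.
d = Dict{String,Any}("m" => 1.0, "x[2:3]" => [0.6, 0.4])
println(valtype(d))    # the compiler only knows this about `d[key]`

# A flattened representation keeps a concrete element type
# and stores all values contiguously.
flat = [1.0, 0.6, 0.4]
println(eltype(flat))
```

Anything read out of `d` has to be dispatched on at runtime, while slices of `flat` always have a known element type.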
Current Metadata used by VarInfo
The Metadata type in VarInfo is a good example of something that satisfies all three properties (of course, the "when possible" in Property (1) is not concrete, but VarInfo uses a NamedTuple of Metadata to achieve this in most common use-cases).
As a reminder, here is what Metadata looks like:
Lines 39 to 72 in e05bb09
struct Metadata{
    TIdcs<:Dict{<:VarName,Int},
    TDists<:AbstractVector{<:Distribution},
    TVN<:AbstractVector{<:VarName},
    TVal<:AbstractVector{<:Real},
    TGIds<:AbstractVector{Set{Selector}},
}
    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
    idcs::TIdcs # Dict{<:VarName,Int}
    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
    vns::TVN # AbstractVector{<:VarName}
    # Vector of index ranges in `vals` corresponding to `vns`
    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
    ranges::Vector{UnitRange{Int}}
    # Vector of values of all the univariate, multivariate and matrix variables
    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
    vals::TVal # AbstractVector{<:Real}
    # Vector of distributions correpsonding to `vns`
    dists::TDists # AbstractVector{<:Distribution}
    # Vector of sampler ids corresponding to `vns`
    # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set`
    gids::TGIds # AbstractVector{Set{Selector}}
    # Number of `observe` statements before each random variable is sampled
    orders::Vector{Int}
    # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]`
    flags::Dict{String,BitVector}
end
Most important for a dict-like storage of values are the following lines:
Lines 46 to 58 in e05bb09
    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
    idcs::TIdcs # Dict{<:VarName,Int}
    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
    vns::TVN # AbstractVector{<:VarName}
    # Vector of index ranges in `vals` corresponding to `vns`
    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
    ranges::Vector{UnitRange{Int}}
    # Vector of values of all the univariate, multivariate and matrix variables
    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
    vals::TVal # AbstractVector{<:Real}
With this, it's fairly easy to implement nice indexing behavior for VarInfo. Here's a simple sketch of what a getindex could look like for Metadata:
function Base.getindex(metadata::Metadata, varname::VarName)
    # Get the index for this `varname`
    idx = metadata.idcs[varname]
    # Get the range for this `varname`
    r = metadata.ranges[idx]
    # Extract the value (the field is called `vals`, see the struct above).
    return metadata.vals[r]
end
This is effectively the getval currently implemented, which results in a Vector holding the flattened representation of vn.
Our current implementation of Base.getindex for VarInfo then contains more complexity to convert the Vector back into the original form expected by the corresponding distribution, and its usage looks like
varinfo[varname, dist]
Since the dist is also stored in the Metadata, the above in fact works the same if you do varinfo[varname] when varinfo isa VarInfo and not a SimpleVarInfo. But, as has been discussed many times before, this is not great because it doesn't properly handle dynamic constraints, etc.; we want to use the dist available at the point of indexing, not the one from the construction of the varinfo.
Nonetheless, the value-storage part of Metadata arguably provides quite a nice way to store values in a dict-like way while satisfying the three properties above.
So why not just use Metadata?
Well, we probably should be. But if we're doing so, we should probably simplify its structure quite a bit.
For example, should we drop the following fields?
- dists: As mentioned, this is often not the correct thing to use.
- gids: This is used by the Gibbs sampler, and will at some point not be of use anymore since we now have ways of programmatically conditioning and deconditioning models.
- orders: Only used by particle methods to keep track of the number of observe statements hit. This should probably either be moved somewhere else or at least not be hardcoded into the "main" dict-like object.
- flags: This might be generally useful, but the flags currently used (istrans and delete) are no longer that useful (istrans should be replaced by explicit transformations, as is done in SimpleVarInfo, and delete should also no longer be needed as we now have a clear way of indicating whether we're running a model in "sampling mode" or not using SamplingContext).
But the problem with doing this is that we'll break a lot of code that currently depends on VarInfo functioning as is.
This is also the main reason why we introduced SimpleVarInfo: to allow us to create simpler and different representations of varinfos without breaking existing code.
So what should we do?
For now, it might be a good idea to just introduce a type very similar to Metadata but simpler in its form, i.e. mainly just a value container.
We could either implement our own, or we could see if there are existing implementations in the ecosystem that could benefit us, e.g. Dictionaries.jl seems like it might be suitable.
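To make the "mainly just a value container" idea concrete, here is a minimal sketch in plain Julia. The type `FlatValues` and everything about its API are hypothetical and not part of DynamicPPL; it only keeps the idcs, vns, ranges and vals fields from Metadata, uses generic keys in place of VarName, and assumes variables are stored in insertion order. The one addition over the getindex sketch above is a setindex! that lets the stored length of a variable change, which is what linking requires.

```julia
# Hypothetical value container: only the value-storage fields of `Metadata`.
struct FlatValues{K,T}
    idcs::Dict{K,Int}               # key -> position in `vns` and `ranges`
    vns::Vector{K}                  # keys in insertion order
    ranges::Vector{UnitRange{Int}}  # slice of `vals` owned by each key
    vals::Vector{T}                 # contiguous, concretely typed storage
end

FlatValues{K,T}() where {K,T} =
    FlatValues{K,T}(Dict{K,Int}(), K[], UnitRange{Int}[], T[])

# Same lookup as the `getindex` sketch for `Metadata`.
Base.getindex(fv::FlatValues, key) = fv.vals[fv.ranges[fv.idcs[key]]]

function Base.setindex!(fv::FlatValues, val::AbstractVector, key)
    if !haskey(fv.idcs, key)
        # New variable: append its values at the end of `vals`.
        start = length(fv.vals) + 1
        push!(fv.vns, key)
        push!(fv.ranges, start:(start + length(val) - 1))
        append!(fv.vals, val)
        fv.idcs[key] = length(fv.vns)
        return fv
    end
    i = fv.idcs[key]
    r = fv.ranges[i]
    # Replace the old slice; the new value may be shorter or longer.
    splice!(fv.vals, r, val)
    shift = length(val) - length(r)
    fv.ranges[i] = first(r):(first(r) + length(val) - 1)
    # Every later variable moves by the change in length.
    for j in (i + 1):length(fv.ranges)
        fv.ranges[j] = fv.ranges[j] .+ shift
    end
    return fv
end

fv = FlatValues{String,Float64}()
fv["x[1]"] = [-0.1]
fv["x[2:3]"] = [0.6, 0.4]
fv["y"] = [2.0]
fv["x[2:3]"] = [0.4]  # made-up "linked" value, one element shorter
println(fv["x[2:3]"], " ", fv["y"], " ", fv.vals)
```

Shifting the ranges on every resize costs time linear in the number of variables, so a real implementation would probably want to resize all variables in one pass or allow gaps in vals; that trade-off is left open here.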