
Need alternative to NamedTuple for SimpleVarInfo #528

@torfjelde

Problem

Now that we properly support usage of different sizes in the underlying storage of the varinfo after linking, the current usage of NamedTuple for both the "ground truth" in TestUtils, e.g.

function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal, kwargs...)
    for vn in vns
        @test isequal(vi[vn], get(vals, vn); kwargs...)
    end
end

and in SimpleVarInfo, makes less sense than it did before.

To see why, let's consider the following example:

julia> using DynamicPPL, Distributions

julia> @model function demo()
           x = Vector{Float64}(undef, 5)
           x[1] ~ Normal()
           x[2:3] ~ Dirichlet([1.0, 2.0])

           return (x=x,)
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> nt = model()
(x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],)

julia> # Construct `SimpleVarInfo` from `nt`.
       vi = SimpleVarInfo(nt)
SimpleVarInfo((x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],), 0.0)

julia> vn = @varname(x[2:3])
x[2:3]

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
 0.6662187241949805
 0.3337812758050194

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 1 elements to 2 destinations
Stacktrace:
  [1] throw_setindex_mismatch(X::Vector{Float64}, I::Tuple{Int64})
    @ Base ./indices.jl:191
  [2] setindex_shape_check
    @ ./indices.jl:245 [inlined]
  [3] setindex!
    @ ./array.jl:994 [inlined]
  [4] _setindex!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:480 [inlined]
  [5] may
    @ ~/.julia/packages/BangBang/FUkah/src/core.jl:9 [inlined]
  [6] setindex!!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:478 [inlined]
  [7] set(obj::Vector{Float64}, lens::BangBang.SetfieldImpl.Lens!!{Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, value::Vector{Float64})
    @ BangBang.SetfieldImpl ~/.julia/packages/BangBang/FUkah/src/setfield.jl:34
  [8] set
    @ ~/.julia/packages/Setfield/PdKfV/src/lens.jl:188 [inlined]
  [9] set
    @ ~/.julia/packages/BangBang/FUkah/src/setfield.jl:17 [inlined]
 [10] set!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/utils.jl:354 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Setfield/PdKfV/src/sugar.jl:197 [inlined]
 [12] setindex!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/simple_varinfo.jl:339 [inlined]
 [13] tilde_assume(#unused#::DynamicPPL.DynamicTransformationContext{false}, right::Dirichlet{Float64, Vector{Float64}, Float64}, vn::VarName{:x, Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:19
 [14] tilde_assume!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/context_implementations.jl:117 [inlined]
 [15] demo(__model__::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext}, __varinfo__::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, __context__::DynamicPPL.DynamicTransformationContext{false})
    @ Main ./REPL[47]:4
 [16] _evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:963 [inlined]
 [17] evaluate_threadunsafe!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:936 [inlined]
 [18] evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:889 [inlined]
 [19] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:86 [inlined]
 [20] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:384 [inlined]
 [21] link!!(vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, model::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:378
 [22] top-level scope
    @ REPL[53]:2

The issue here can really just be boiled down to the fact that we're trying to use the varname

julia> vn
x[2:3]

to index into a NamedTuple in which, after the transformation, the value of x[2:3] is represented by a length-1 vector rather than a length-2 vector.
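The same failure can be reproduced without DynamicPPL. Below is a minimal sketch in plain Julia, where the length-1 vector `y` is only a stand-in for the linked value of `x[2:3]` (the number itself is arbitrary):

```julia
# Storage laid out for the unlinked values, as in the NamedTuple above.
x = zeros(5)

# Stand-in for the linked value of `x[2:3]`: one element instead of two.
y = [1.2867758943235161]

# Assigning into the original 2-element slot fails the shape check.
err = try
    x[2:3] = y
    nothing
catch e
    e
end

@show err isa DimensionMismatch
```

The slot `x[2:3]` is fixed by the layout of the original array, so a value whose length changes under linking has nowhere to go.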

In contrast, SimpleVarInfo{<:AbstractDict} will work just fine because here each varname gets its own entry:

julia> # Construct `SimpleVarInfo` using a dict now.
       vi = SimpleVarInfo(rand(OrderedDict, model))
SimpleVarInfo(OrderedDict{Any, Any}(x[1] => -0.12337922752695839, x[2:3] => [0.7836009759179734, 0.21639902408202646]), 0)

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
 0.7836009759179734
 0.21639902408202646

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);

julia> # (✓) Everything works nicely
       vi_linked[vn]
1-element Vector{Float64}:
 1.2867758943235161

"Luckily" it has always been the plan that SimpleVarInfo should be able to use different underlying representations fairly easily, e.g. I've successfully used it with ComponentVector from ComponentArrays.jl many times before. And so we should probably find a more flexible default representation for SimpleVarInfo that can be used in more cases.

Solution

Option 1: Use OrderedDict by default

This one is obviously not great for performance reasons, but it will "just work" in all cases and it's very simple to reason about.

Option 2: Dict-like flattened representation

In an ideal world, the underlying representation of the values in a varinfo would have the following properties:

  1. It's type-stable, when possible.
  2. It's contiguous in memory, when possible.
  3. It's indexable by VarName.

Something like an OrderedDict fails in two regards:

  1. It's not contiguous in memory.
  2. Type-stability is not guaranteed, unless we create a dictionary for each eltype or something similar.
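The type-stability point can be seen with a plain `Dict`. This is a small sketch that uses strings as stand-ins for `VarName` keys (an assumption made only to keep the example free of dependencies):

```julia
# String keys stand in for `VarName`s; the values mix a scalar and a vector.
d = Dict("x[1]" => -0.12, "x[2:3]" => [0.78, 0.22])

# The value type must cover both `Float64` and `Vector{Float64}`,
# so it cannot be a concrete type.
@show valtype(d)
@show isconcretetype(valtype(d))
```

Because the value type is abstract, the compiler cannot infer the type of `d[key]`, which is what the "dictionary for each eltype" workaround would avoid.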

Current Metadata used by VarInfo

The Metadata type in VarInfo is a good example of something that satisfies all three properties (of course, the "when possible" in Property (1) is not concrete, but VarInfo uses a NamedTuple of Metadata to achieve this in most common use-cases).

As a reminder, here is what Metadata looks like:

struct Metadata{
    TIdcs<:Dict{<:VarName,Int},
    TDists<:AbstractVector{<:Distribution},
    TVN<:AbstractVector{<:VarName},
    TVal<:AbstractVector{<:Real},
    TGIds<:AbstractVector{Set{Selector}},
}
    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
    idcs::TIdcs # Dict{<:VarName,Int}
    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
    vns::TVN # AbstractVector{<:VarName}
    # Vector of index ranges in `vals` corresponding to `vns`
    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
    ranges::Vector{UnitRange{Int}}
    # Vector of values of all the univariate, multivariate and matrix variables
    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
    vals::TVal # AbstractVector{<:Real}
    # Vector of distributions corresponding to `vns`
    dists::TDists # AbstractVector{<:Distribution}
    # Vector of sampler ids corresponding to `vns`
    # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set`
    gids::TGIds # AbstractVector{Set{Selector}}
    # Number of `observe` statements before each random variable is sampled
    orders::Vector{Int}
    # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresponding to `vns[i]`
    flags::Dict{String,BitVector}
end

The lines most relevant to a dict-like storage of values are the following:

# Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
idcs::TIdcs # Dict{<:VarName,Int}
# Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
vns::TVN # AbstractVector{<:VarName}
# Vector of index ranges in `vals` corresponding to `vns`
# Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
ranges::Vector{UnitRange{Int}}
# Vector of values of all the univariate, multivariate and matrix variables
# The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
vals::TVal # AbstractVector{<:Real}

With this, it's fairly easy to implement nice indexing behavior for VarInfo. Here's a simple sketch of what a getindex could look like for Metadata:

function Base.getindex(metadata::Metadata, varname::VarName)
    # Get the index for this `varname`
    idx = metadata.idcs[varname]
    # Get the range for this `varname`
    r = metadata.ranges[idx]
    # Extract the value.
    return metadata.vals[r]
end

This is effectively the getval currently implemented:

https://github.com/TuringLang/DynamicPPL.jl/blob/e05bb0935a1e1a06027c603cc04c20f23195a6c4/src/varinfo.jl#L318C44-L318C44

This then results in a Vector of the flattened representation of vn.

Our current implementation of Base.getindex for VarInfo then contains more complexity to convert the Vector back into the original form expected by the corresponding distribution, and its usage looks like

varinfo[varname, dist]

Since the dist is also stored in the Metadata, the above in fact works the same if you do varinfo[varname], provided varinfo isa VarInfo and not a SimpleVarInfo. But, as has been discussed many times before, this is not great because it doesn't properly handle dynamic constraints, etc.; we want to use the dist available at the point of indexing, not the one stored when the varinfo was constructed.

Nonetheless, the value-storage part of Metadata arguably provides quite a nice way to store values in a dict-like way while satisfying the three properties above.

So why not just use Metadata?

Well, we probably should be. But if we're doing so, we should probably simplify its structure quite a bit.

For example, should we drop the following fields?

  • dists: As mentioned, this is often not the correct thing to use.
  • gids: This is used by the Gibbs sampler, and will at some point not be of use anymore since we now have ways of programmatically conditioning and deconditioning models.
  • orders: Only used by particle methods to keep track of the number of observe statements hit. This should probably either be moved somewhere else or at least not be hardcoded into the "main" dict-like object.
  • flags: This might be generally useful, but the flags currently used (istrans and delete) are no longer that useful (istrans should be replaced by explicit transformations, as is done in SimpleVarInfo, and delete should also no longer be needed as we now have a clear way of indicating whether we're running a model in "sampling mode" or not using SamplingContext).

But the problem with doing this is that we'll break a lot of code that currently depends on VarInfo functioning as is.
This is also the main reason why we introduced SimpleVarInfo: to allow us to create simpler and different representations of varinfos without breaking existing code.

So what should we do?

For now, it might be a good idea to just introduce a type very similar to Metadata but simpler in its form, i.e. mainly just a value container.

We could either implement our own, or we could see if there are existing implementations in the ecosystem that could benefit us, e.g. Dictionaries.jl seems like it might be suitable.
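To make the "implement our own" route concrete, here is a rough sketch of what such a value container could look like, modelled on the idcs/vns/ranges/vals fields of Metadata. Everything in it is hypothetical: the name `FlatValues`, the helper `replace_value!`, and the use of strings in place of `VarName` keys are assumptions made for illustration, not existing DynamicPPL API.

```julia
# Hypothetical minimal value container; names are illustrative only.
struct FlatValues{K,T<:Real}
    idcs::Dict{K,Int}               # key => position in `ks` and `ranges`
    ks::Vector{K}                   # keys in insertion order
    ranges::Vector{UnitRange{Int}}  # slice of `vals` owned by each key
    vals::Vector{T}                 # all values, flattened and contiguous
end

# Empty container for keys of type `K` and values of type `T`.
FlatValues{K,T}() where {K,T<:Real} =
    FlatValues{K,T}(Dict{K,Int}(), K[], UnitRange{Int}[], T[])

# Append a new key together with its flattened value.
function Base.push!(fv::FlatValues, kv::Pair)
    k, v = kv
    haskey(fv.idcs, k) && throw(ArgumentError("key $k is already present"))
    start = length(fv.vals) + 1
    append!(fv.vals, v)
    push!(fv.ks, k)
    push!(fv.ranges, start:length(fv.vals))
    fv.idcs[k] = length(fv.ks)
    return fv
end

# Look up the flattened value of a key.
Base.getindex(fv::FlatValues, k) = fv.vals[fv.ranges[fv.idcs[k]]]

# Replace the value of `k`, allowing its length to change (as linking may do),
# and shift the ranges of all keys stored after it.
function replace_value!(fv::FlatValues, k, v)
    i = fv.idcs[k]
    r = fv.ranges[i]
    splice!(fv.vals, r, v)
    shift = length(v) - length(r)
    fv.ranges[i] = first(r):(last(r) + shift)
    for j in (i + 1):length(fv.ranges)
        fv.ranges[j] = fv.ranges[j] .+ shift
    end
    return fv
end

# Strings stand in for `VarName`s here.
fv = FlatValues{String,Float64}()
push!(fv, "x[1]" => [-0.12])
push!(fv, "x[2:3]" => [0.78, 0.22])
push!(fv, "y" => [3.0])

# Shrink `x[2:3]` from two elements to one, mimicking a linked simplex value.
replace_value!(fv, "x[2:3]", [1.29])
@show fv["x[2:3]"] fv["y"] fv.vals
```

In this sketch a change of length costs a `splice!` plus a pass over the later ranges, so it is linear in the number of stored values. Whether that is acceptable, or whether something like Dictionaries.jl handles it better, is exactly the design question left open above.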
