Skip to content

Commit fac8641

Browse files
authored
VNT Part 2: Solving the issue of block elements (#1180)
* ArrayLikeBlock WIP * ArrayLikeBlock WIP2 * Improve type stability of ArrayLikeBlock stuff * Test more invariants * Actually run VNT tests * Implement show for ArrayLikeBlock * Change keys on VNT to return an array * Fix keys and some tests for PartialArray * Improve type stability * Fix keys for PartialArray * More ArrayLikeBlock tests * Add docstrings * Remove redundant code, improve documentation * Add Base.size(::RangeAndLinked) * Fix issues with RangeAndLinked and VNT * Write more design doc for ArrayLikeBlocks
1 parent 44be19d commit fac8641

File tree

8 files changed

+464
-48
lines changed

8 files changed

+464
-48
lines changed

docs/src/internals/varnamedtuple.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,29 @@ You can also set the elements with `vnt = setindex!!(vnt, @varname(a[1]), 3.0)`,
144144
At this point you can not set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, @varname(a[5]), 5.0)`.
145145
The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is.
146146

147+
## Non-Array blocks with `IndexLens`es
148+
149+
The above is all that is needed for setting regular scalar values.
150+
However, in DynamicPPL we also have a particular need for something slightly odd:
151+
We sometimes need to do calls like `setindex!!(vnt, @varname(a[1:5]), val)` on a `val` that is _not_ an `AbstractArray`, or even iterable at all.
152+
Normally this would error: As a scalar value with size `()`, `val` is the wrong size to be set with `@varname(a[1:5])`, which clearly wants something with size `(5,)`.
153+
However, we want to allow this even if `val` is not an iterable, if it is some object for which `size` is well-defined, and `size(val) == (5,)`.
154+
In DynamicPPL this comes up when storing e.g. the priors of a model, where a random variable like `@varname(a[1:5])` may be associated with a prior that is a 5-dimensional distribution.
155+
156+
Internally, a `PartialArray` is just a regular `Array` with a mask saying which elements have been set.
157+
Hence we can't store `val` directly in the same `PartialArray`:
158+
We need it to take up a sub-block of the array, in our example case a sub-block of length 5.
159+
To this end, internally, `PartialArray` uses a wrapper type called `ArrayLikeWrapper`, that stores `val` together with the indices that are being used to set it.
160+
The `PartialArray` has all its corresponding elements, in our example elements 1, 2, 3, 4, and, 5, point to the same wrapper object.
161+
162+
While such blocks can be stored using a wrapper like this, some care must be taken in indexing into these blocks.
163+
For instance, after setting a block with `setindex!!(vnt, @varname(a[1:5]), val)`, we can't `getindex(vnt, @varname(a[1]))`, since we can't return "the first element of five in `val`", because `val` may not be indexable in any way.
164+
Similarly, if next we set `setindex!!(vnt, @varname(a[1]), some_other_value)`, that should invalidate/delete the elements `@varname(a[2:5])`, since the block only makes sense as a whole.
165+
Because of these reasons, setting and getting blocks of well-defined size like this is allowed with `VarNamedTuple`s, but _only by always using the full range_.
166+
For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the only valid `getindex` key to access `val` is `@varname(a[1:5])`;
167+
Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor for anything else that overlaps with `@varname(a[1:5])`.
168+
`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element.
169+
147170
## Limitations
148171

149172
This design has a several of benefits, for performance and generality, but it also has limitations:

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module DynamicPPLMarginalLogDensitiesExt
22

3-
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
3+
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
44
using MarginalLogDensities: MarginalLogDensities
55

66
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
@@ -105,11 +105,9 @@ function DynamicPPL.marginalize(
105105
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
106106
# Determine the indices for the variables to marginalise out.
107107
varindices = mapreduce(vcat, marginalized_varnames) do vn
108-
if DynamicPPL.getoptic(vn) === identity
109-
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
110-
else
111-
ldf._varname_ranges[vn].range
112-
end
108+
# The type assertion helps in cases where the model is type unstable and thus
109+
# `varname_ranges` may have an abstract element type.
110+
(ldf._varname_ranges[vn]::RangeAndLinked).range
113111
end
114112
mld = MarginalLogDensities.MarginalLogDensity(
115113
LogDensityFunctionWrapper(ldf, varinfo),

src/contexts/init.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,17 @@ an unlinked value.
206206
207207
$(TYPEDFIELDS)
208208
"""
209-
struct RangeAndLinked
209+
struct RangeAndLinked{T<:Tuple}
210210
# indices that the variable corresponds to in the vectorised parameter
211211
range::UnitRange{Int}
212212
# whether it's linked
213213
is_linked::Bool
214+
# original size of the variable before vectorisation
215+
original_size::T
214216
end
215217

218+
Base.size(ral::RangeAndLinked) = ral.original_size
219+
216220
"""
217221
VectorWithRanges{Tlink}(
218222
varname_ranges::VarNamedTuple,
@@ -247,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}}
247251
end
248252

249253
function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
250-
return vr.varname_ranges[vn]
254+
# The type assertion does nothing if VectorWithRanges has concrete element types, as is
255+
# the case for all type stable models. However, if the model is not type stable,
256+
# vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to
257+
# assert that it is a RangeAndLinked, because even though it remains non-concrete,
258+
# it'll allow the compiler to infer the types of `range` and `is_linked`.
259+
return vr.varname_ranges[vn]::RangeAndLinked
251260
end
252261
function init(
253262
::Random.AbstractRNG,

src/logdensityfunction.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
330330
for (vn, idx) in md.idcs
331331
is_linked = md.is_transformed[idx]
332332
range = md.ranges[idx] .+ (start_offset - 1)
333-
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
333+
orig_size = varnamesize(vn)
334+
all_ranges = BangBang.setindex!!(
335+
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
336+
)
334337
offset += length(range)
335338
end
336339
return all_ranges, offset
@@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
341344
for (vn, idx) in vnv.varname_to_index
342345
is_linked = vnv.is_unconstrained[idx]
343346
range = vnv.ranges[idx] .+ (start_offset - 1)
344-
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
347+
orig_size = varnamesize(vn)
348+
all_ranges = BangBang.setindex!!(
349+
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
350+
)
345351
offset += length(range)
346352
end
347353
return all_ranges, offset

src/varname.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected.
4141
) where {s,missings,_F,_a,_T}
4242
return s in missings
4343
end
44+
45+
# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL.
46+
"""
47+
varnamesize(vn::VarName)
48+
49+
Return the size of the object referenced by this VarName.
50+
51+
```jldoctest
52+
julia> varnamesize(@varname(a))
53+
()
54+
55+
julia> varnamesize(@varname(b[1:3, 2]))
56+
(3,)
57+
58+
julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1]))
59+
(4, 4)
60+
"""
61+
function varnamesize(vn::VarName)
62+
l = AbstractPPL._last(vn.optic)
63+
if l isa Accessors.IndexLens
64+
return reduce((x, y) -> tuple(x..., y...), map(size, l.indices))
65+
else
66+
return ()
67+
end
68+
end

0 commit comments

Comments
 (0)