Skip to content

Commit 8b5bd47

Browse files
committed
Improve VarNamedVector docs
1 parent be77c36 commit 8b5bd47

File tree

4 files changed

+164
-19
lines changed

4 files changed

+164
-19
lines changed

src/varinfo.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,6 +1835,20 @@ function BangBang.push!!(
18351835
return vi
18361836
end
18371837

1838+
function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...)
1839+
push!(getmetadata(vi, vn), vn, val, args...)
1840+
return vi
1841+
end
1842+
1843+
function Base.push!(vi::VectorVarInfo, pair::Pair, args...)
1844+
vn, val = pair
1845+
return push!(vi, vn, val, args...)
1846+
end
1847+
1848+
# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't
1849+
# exist in the TypedVarInfo already. We could implement it in the cases where it it does
1850+
# exist, but that feels a bit pointless. I think we should rather rely on `push!!`.
1851+
18381852
function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
18391853
val = tovec(r)
18401854
meta.idcs[vn] = length(meta.idcs) + 1
@@ -1852,6 +1866,11 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
18521866
return meta
18531867
end
18541868

1869+
function Base.delete!(vi::VarInfo, vn::VarName)
1870+
delete!(getmetadata(vi, vn), vn)
1871+
return vi
1872+
end
1873+
18551874
"""
18561875
setorder!(vi::VarInfo, vn::VarName, index::Int)
18571876

src/varnamedvector.jl

Lines changed: 134 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,107 @@
33
44
A container that stores values in a vectorised form, but indexable by variable names.
55
6-
When indexed by integers or `Colon`s, e.g. `vnv[2]` or `vnv[:]`, `VarNamedVector` behaves
6+
When indexed with integers or `Colon`s, e.g. `vnv[2]` or `vnv[:]`, `VarNamedVector` behaves
77
like a `Vector`, and returns the values as they are stored. The stored form is always
88
vectorised, for instance matrix variables have been flattened, and may be further
99
transformed to achieve linking.
1010
11-
When indexed by `VarName`s, e.g. `vnv[@varname(x)]`, `VarNamedVector` returns the values
11+
When indexed with `VarName`s, e.g. `vnv[@varname(x)]`, `VarNamedVector` returns the values
1212
in the original space. For instance, a linked matrix variable is first inverse linked and
1313
then reshaped to its original form before returning it to the caller.
1414
1515
`VarNamedVector` also stores a boolean for whether a variable has been transformed to
1616
unconstrained Euclidean space or not.
1717
18+
Internally, `VarNamedVector` stores the values of all variables in a single contiguous
19+
vector.
20+
1821
# Fields
22+
1923
$(FIELDS)
24+
25+
# Extended help
26+
27+
The values for different variables are internally all stored in a single vector. For
28+
instance,
29+
```jldoctest varnamedvector-struct
30+
julia> using DynamicPPL: VarNamedVector, @varname, push!, update!
31+
32+
julia> vnv = VarNamedVector();
33+
34+
julia> push!(vnv, @varname(x) => [0.0, 1.0]);
35+
36+
julia> push!(vnv, @varname(y) => fill(3, (3,3)));
37+
38+
julia> vnv.vals
39+
11-element Vector{Real}:
40+
0.0
41+
1.0
42+
3
43+
3
44+
3
45+
3
46+
3
47+
3
48+
3
49+
3
50+
3
51+
```
52+
53+
The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to
54+
which variable. The `transforms` field stores the transformations that needed to transform
55+
the vectorised internal storage back to its original form:
56+
57+
```jldoctest varnamedvector-struct
58+
julia> vnv.transforms[vnv.varname_to_index[@varname(y)]]
59+
DynamicPPL.ReshapeTransform{Tuple{Int64, Int64}}((3, 3))
60+
```
61+
62+
If a variable is updated with a new value that is of a smaller dimension than the old
63+
value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive.
64+
65+
```jldoctest varnamedvector-struct
66+
julia> update!(vnv, @varname(y), fill(2, (2, 2)))
67+
68+
julia> vnv.vals
69+
11-element Vector{Real}:
70+
0.0
71+
1.0
72+
2
73+
2
74+
2
75+
2
76+
3
77+
3
78+
3
79+
3
80+
3
81+
82+
julia> vnv.num_inactive
83+
OrderedDict{Int64, Int64} with 1 entry:
84+
2 => 5
85+
```
86+
87+
This helps avoid unnecessary memory allocations for values that repeatedly change dimension.
88+
The user does not have to worry about the inactive entries as long as they use functions
89+
like `update!` and `getindex!` rather than directly accessing `vnv.vals`.
90+
91+
```jldoctest varnamedvector-struct
92+
julia> vnv[@varname(y)]
93+
2×2 Matrix{Real}:
94+
2 2
95+
2 2
96+
97+
98+
julia> vnv[:]
99+
6-element Vector{Real}:
100+
0.0
101+
1.0
102+
2
103+
2
104+
2
105+
2
106+
```
20107
"""
21108
struct VarNamedVector{
22109
K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector
@@ -64,6 +151,7 @@ struct VarNamedVector{
64151
Inactive entries are elements in `vals` that are not part of the value of any variable.
65152
They arise when a variable is set to a new value with a different dimension, in-place.
66153
Inactive entries always come after the last active entry for the given variable.
154+
See the extended help with `??VarNamedVector` for more details.
67155
"""
68156
num_inactive::OrderedDict{Int,Int}
69157

@@ -80,23 +168,40 @@ struct VarNamedVector{
80168
length(varnames) != length(transforms) ||
81169
length(varnames) != length(is_unconstrained) ||
82170
length(varnames) != length(varname_to_index)
83-
msg = "Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: $(length(varnames)), ranges: $(length(ranges)), transforms: $(length(transforms)), is_unconstrained: $(length(is_unconstrained)), varname_to_index: $(length(varname_to_index))."
171+
msg = (
172+
"Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: " *
173+
"$(length(varnames)), ranges: " *
174+
"$(length(ranges)), " *
175+
"transforms: $(length(transforms)), " *
176+
"is_unconstrained: $(length(is_unconstrained)), " *
177+
"varname_to_index: $(length(varname_to_index))."
178+
)
84179
throw(ArgumentError(msg))
85180
end
86181

87182
num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
88183
if num_vals != length(vals)
89-
msg = "The total number of elements in `vals` ($(length(vals))) does not match the sum of the lengths of the ranges and the number of inactive entries ($(num_vals))."
184+
msg = (
185+
"The total number of elements in `vals` ($(length(vals))) does not match " *
186+
"the sum of the lengths of the ranges and the number of inactive entries " *
187+
"($(num_vals))."
188+
)
90189
throw(ArgumentError(msg))
91190
end
92191

93-
if Set(values(varname_to_index)) != Set(1:length(varnames))
94-
msg = "The values of `varname_to_index` are not valid indices."
192+
if Set(values(varname_to_index)) != Set(axes(varnames, 1))
193+
msg = (
194+
"The set of values of `varname_to_index` is not the set of valid indices " *
195+
"for `varnames`."
196+
)
95197
throw(ArgumentError(msg))
96198
end
97199

98200
if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
99-
msg = "The keys of `num_inactive` are not valid indices."
201+
msg = (
202+
"The keys of `num_inactive` are not a subset of the values of " *
203+
"`varname_to_index`."
204+
)
100205
throw(ArgumentError(msg))
101206
end
102207

@@ -107,7 +212,10 @@ struct VarNamedVector{
107212
for vn2 in keys(varname_to_index)
108213
vn1 === vn2 && continue
109214
if subsumes(vn1, vn2)
110-
msg = "Variables in a VarNamedVector should not subsume each other, but $vn1 subsumes $vn2."
215+
msg = (
216+
"Variables in a VarNamedVector should not subsume each other, " *
217+
"but $vn1 subsumes $vn2, i.e. $vn2 describes a subrange of $vn1."
218+
)
111219
throw(ArgumentError(msg))
112220
end
113221
end
@@ -241,14 +349,18 @@ end
241349
"""
242350
has_inactive(vnv::VarNamedVector)
243351
244-
Returns `true` if `vnv` has inactive ranges.
352+
Returns `true` if `vnv` has inactive entries.
353+
354+
See also: [`num_inactive`](@ref)
245355
"""
246356
has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive)
247357

248358
"""
249359
num_inactive(vnv::VarNamedVector)
250360
251361
Return the number of inactive entries in `vnv`.
362+
363+
See also: [`has_inactive`](@ref), [`num_allocated`](@ref)
252364
"""
253365
num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive))
254366

@@ -262,29 +374,33 @@ num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0)
262374

263375
"""
264376
num_allocated(vnv::VarNamedVector)
377+
num_allocated(vnv::VarNamedVector[, vn::VarName])
378+
num_allocated(vnv::VarNamedVector[, idx::Int])
265379
266-
Returns the number of allocated entries in `vnv`, both active and inactive.
267-
"""
268-
num_allocated(vnv::VarNamedVector) = length(vnv.vals)
380+
Return the number of allocated entries in `vnv`, both active and inactive.
269381
270-
"""
271-
num_allocated(vnv::VarNamedVector, vn::VarName)
382+
If either a `VarName` or an `Int` index is specified, only count entries allocated for that
383+
variable.
272384
273-
Returns the number of allocated entries for `vn` in `vnv`, both active and inactive.
385+
Allocated entries take up memory in `vnv.vals`, but, if inactive, may not currently hold any
386+
meaningful data. One can remove them with [`contiguify!`](@ref), but doing so may cause more
387+
memory allocations in the future if variables change dimension.
274388
"""
389+
num_allocated(vnv::VarNamedVector) = length(vnv.vals)
275390
num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn))
276391
function num_allocated(vnv::VarNamedVector, idx::Int)
277392
return length(getrange(vnv, idx)) + num_inactive(vnv, idx)
278393
end
279394

280395
# Basic array interface.
281396
Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals)
282-
Base.length(vnv::VarNamedVector) =
397+
function Base.length(vnv::VarNamedVector)
283398
if !has_inactive(vnv)
284-
length(vnv.vals)
399+
return length(vnv.vals)
285400
else
286-
sum(length, vnv.ranges)
401+
return sum(length, vnv.ranges)
287402
end
403+
end
288404
Base.size(vnv::VarNamedVector) = (length(vnv),)
289405
Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames)
290406

test/varinfo.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
111111
@test vi[vn] == 3 * r
112112
@test vi[SampleFromPrior()][1] == 3 * r
113113

114+
# TODO(mhauru) Implement these functions for SimpleVarInfo too.
115+
if vi isa VarInfo
116+
delete!(vi, vn)
117+
@test isempty(vi)
118+
vi = push!!(vi, vn, r, dist, gid)
119+
end
120+
114121
vi = empty!!(vi)
115122
@test isempty(vi)
116123
return push!!(vi, vn, r, dist, gid)

test/varnamedvector.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,10 @@ end
297297
# Should not be possible to `push!` existing varname.
298298
@test_throws ArgumentError push!(vnv, vn, val)
299299
else
300-
push!(vnv, vn, val)
300+
vnv_copy = deepcopy(vnv)
301+
push!(vnv_copy, vn, val)
302+
@test vnv_copy[vn] == val
303+
push!(vnv, (vn => val))
301304
@test vnv[vn] == val
302305
end
303306
end

0 commit comments

Comments
 (0)