Skip to content

Commit 19978ec

Browse files
mhaurutorfjeldegithub-actions[bot]sunxd3
authored
More work on VarNameVector (#637)
* Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Markus Hauru <[email protected]> * Type-stability tests are now correctly using `rand_prior_true` instead of `rand` * `getindex_internal` now calls `getindex` instead of `view`, as the latter can result in type-instability since transformed variables typically result in non-view even if input is a view * Removed seemingly unnecessary definition of `getindex_internal` * Fixed references to `newmetadata` which has been replaced by `replace_values` * Made implementation of `recombine` more explicit * Added docstrings for `untyped_varinfo` and `typed_varinfo` * Added TODO comment about implementing `view` for `VarInfo` * Fixed potential infinite recursion as suggested by @mhauru * added docstring to `from_vec_trnasform_for_size * Replaced references to `vectorize(dist, x)` with `tovec(x)` * Fixed docstring * Update src/extract_priors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump minor version since this is a breaking change * Apply suggestions from code review Co-authored-by: Markus Hauru <[email protected]> * Update src/varinfo.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Apply suggestions from code review * Apply suggestions from code review * Update src/extract_priors.jl Co-authored-by: Xianda Sun <[email protected]> * Added fix for product distributions of targets with changing support + tests * Addeed tests for product of distributions with dynamic support * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix typos, improve docstrings * Use Accessors rather than Setfield * Simplify group_by_symbol * Add short_varinfo_name(::VectorVarInfo) * Add tests for subset * Export VectorVarInfo * Tighter type bound for has_varnamevector * Add some VectorVarName methods * Add todo notes, remove dead code, fix a typo. * Bug fixes and small improvements * VarNameVector improvements * Improve generated_quantities and its tests * Improvement to VarNameVector * Fix a test to work with VectorVarName * Fix generated_quantities * Fix type stability issues * Various VarNameVector fixes and improvements * Bump version number * Improvements to generated_quantities * Code formatting * Code style * Add fallback implementation of findinds for VarNameVector * Rename VarNameVector to VarNamedVector * More renaming of VNV. Remove unused VarNamedVector.metadata field. * Rename FromVec to ReshapeTransform * Progress towards having VarNamedVector as storage for SimpleVarInfo * Fix unflatten(vnv::VarNamedVector, vals) * More work on SimpleVarInfo{VarNamedVector} * More tests for SimpleVarInfo{VarNamedVector} * More tests for SimpleVarInfo{VarNamedVector} * Respond to review feedback * Add float_type_with_fallback(::Type{Union{}}) * Move some VNV functions to the correct file * Fix push! for VNV * Rename VNV.is_transformed to VNV.is_unconstrained * Improve VNV docstring * Add VNV inner constructor checks * Reorganise parts of VNV code * Documentation and small fixes for VNV * Rename loosen_types!! and tighten_types, add docstrings and doctests * Rename VarNameVector to VarNamedVector in docs * Documentation and small fixes to VNV * Fix subset(::VarNamedVector, args...) for unconstrained variables. --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]>
1 parent 9b1014d commit 19978ec

23 files changed

+1835
-1107
lines changed

docs/src/internals/varinfo.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,21 +132,21 @@ Hence we obtain a "type-stable when possible"-representation by wrapping it in a
132132

133133
## Efficient storage and iteration
134134

135-
Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNameVector`](@ref):
135+
Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNamedVector`](@ref):
136136

137137
```@docs
138-
DynamicPPL.VarNameVector
138+
DynamicPPL.VarNamedVector
139139
```
140140

141-
In a [`VarNameVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s.
141+
In a [`VarNamedVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s.
142142

143143
This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields:
144144

145145
- `varnames::Vector{VarName}`: the `VarName`s in the order they appear in the `Vector{T}`.
146146
- `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`.
147147
- `transforms::Vector`: the transforms associated with each `VarName`.
148148

149-
Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are then treated according to the following rules:
149+
Mutating functions, e.g. `setindex!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules:
150150

151151
1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc.
152152

@@ -156,7 +156,7 @@ Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are
156156
2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field.
157157
3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`.
158158

159-
This means that `VarNameVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in.
159+
This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in.
160160

161161
For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example:
162162

@@ -195,7 +195,7 @@ DynamicPPL.contiguify!
195195
For example, one might encounter the following scenario:
196196

197197
```@example varinfo-design
198-
vnv = DynamicPPL.VarNameVector(@varname(x) => [true])
198+
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
199199
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")
200200
201201
for i in 1:5
@@ -210,7 +210,7 @@ end
210210
We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage:
211211

212212
```@example varinfo-design
213-
vnv = DynamicPPL.VarNameVector(@varname(x) => [true])
213+
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
214214
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")
215215
216216
for i in 1:5
@@ -225,13 +225,13 @@ for i in 1:5
225225
end
226226
```
227227

228-
This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNameVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous.
228+
This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous.
229229

230230
!!! note
231231

232-
Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing he `VarName`'s transformation with a `DynamicPPL.FromVec`.
232+
Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`.
233233

234-
Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNameVector` as the `metadata` field:
234+
Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field:
235235

236236
```@example varinfo-design
237237
# Type-unstable
@@ -287,23 +287,23 @@ DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x))
287287

288288
### Performance summary
289289

290-
In the end, we have the following "rough" performance characteristics for `VarNameVector`:
290+
In the end, we have the following "rough" performance characteristics for `VarNamedVector`:
291291

292-
| Method | Is blazingly fast? |
293-
|:---------------------------------------:|:--------------------------------------------------------------------------------------------:|
294-
| `getindex` | ${\color{green} \checkmark}$ |
295-
| `setindex!` | ${\color{green} \checkmark}$ |
296-
| `push!` | ${\color{green} \checkmark}$ |
297-
| `delete!` | ${\color{red} \times}$ |
298-
| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size |
299-
| `values_as(::VarNameVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise |
292+
| Method | Is blazingly fast? |
293+
|:----------------------------------------:|:--------------------------------------------------------------------------------------------:|
294+
| `getindex` | ${\color{green} \checkmark}$ |
295+
| `setindex!` | ${\color{green} \checkmark}$ |
296+
| `push!` | ${\color{green} \checkmark}$ |
297+
| `delete!` | ${\color{red} \times}$ |
298+
| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size |
299+
| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise |
300300

301301
## Other methods
302302

303303
```@docs
304-
DynamicPPL.replace_values(::VarNameVector, vals::AbstractVector)
304+
DynamicPPL.replace_values(::VarNamedVector, vals::AbstractVector)
305305
```
306306

307307
```@docs; canonical=false
308-
DynamicPPL.values_as(::VarNameVector)
308+
DynamicPPL.values_as(::VarNamedVector)
309309
```

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
1212
function _check_varname_indexing(c::MCMCChains.Chains)
1313
return DynamicPPL.supports_varname_indexing(c) ||
14-
error("Chains do not support indexing using $vn.")
14+
error("$(typeof(c)) do not support indexing using varnmes.")
1515
end
1616

1717
# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
@@ -41,21 +41,152 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4141
return keys(c.info.varname_to_symbol)
4242
end
4343

44+
"""
45+
generated_quantities(model::Model, chain::MCMCChains.Chains)
46+
47+
Execute `model` for each of the samples in `chain` and return an array of the values
48+
returned by the `model` for each sample.
49+
50+
# Examples
51+
## General
52+
Often you might have additional quantities computed inside the model that you want to
53+
inspect, e.g.
54+
```julia
55+
@model function demo(x)
56+
# sample and observe
57+
θ ~ Prior()
58+
x ~ Likelihood()
59+
return interesting_quantity(θ, x)
60+
end
61+
m = demo(data)
62+
chain = sample(m, alg, n)
63+
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
64+
# from the posterior/`chain`:
65+
generated_quantities(m, chain) # <= results in a `Vector` of returned values
66+
# from `interesting_quantity(θ, x)`
67+
```
68+
## Concrete (and simple)
69+
```julia
70+
julia> using DynamicPPL, Turing
71+
72+
julia> @model function demo(xs)
73+
s ~ InverseGamma(2, 3)
74+
m_shifted ~ Normal(10, √s)
75+
m = m_shifted - 10
76+
77+
for i in eachindex(xs)
78+
xs[i] ~ Normal(m, √s)
79+
end
80+
81+
return (m, )
82+
end
83+
demo (generic function with 1 method)
84+
85+
julia> model = demo(randn(10));
86+
87+
julia> chain = sample(model, MH(), 10);
88+
89+
julia> generated_quantities(model, chain)
90+
10×1 Array{Tuple{Float64},2}:
91+
(2.1964758025119338,)
92+
(2.1964758025119338,)
93+
(0.09270081916291417,)
94+
(0.09270081916291417,)
95+
(0.09270081916291417,)
96+
(0.09270081916291417,)
97+
(0.09270081916291417,)
98+
(0.043088571494005024,)
99+
(-0.16489786710222099,)
100+
(-0.16489786710222099,)
101+
```
102+
"""
44103
function DynamicPPL.generated_quantities(
45104
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
46105
)
47106
chain = MCMCChains.get_sections(chain_full, :parameters)
48107
varinfo = DynamicPPL.VarInfo(model)
49108
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
50109
return map(iters) do (sample_idx, chain_idx)
51-
# Update the varinfo with the current sample and make variables not present in `chain`
52-
# to be sampled.
53-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
110+
if DynamicPPL.supports_varname_indexing(chain)
111+
varname_pairs = _varname_pairs_with_varname_indexing(
112+
chain, varinfo, sample_idx, chain_idx
113+
)
114+
else
115+
varname_pairs = _varname_pairs_without_varname_indexing(
116+
chain, varinfo, sample_idx, chain_idx
117+
)
118+
end
119+
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
120+
return fixed_model()
121+
end
122+
end
123+
124+
"""
125+
_varname_pairs_with_varname_indexing(
126+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
127+
)
54128
55-
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
56-
# `deepcopy` the `varinfo` before passing it to `model`.
57-
model(deepcopy(varinfo))
129+
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
130+
from the chain.
131+
132+
This implementation assumes `chain` can be indexed using variable names, and is the
133+
preffered implementation.
134+
"""
135+
function _varname_pairs_with_varname_indexing(
136+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
137+
)
138+
vns = DynamicPPL.varnames(chain)
139+
vn_parents = Iterators.map(vns) do vn
140+
# The call nested_setindex_maybe! is used to handle cases where vn is not
141+
# the variable name used in the model, but rather subsumed by one. Except
142+
# for the subsumption part, this could be
143+
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
144+
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
145+
DynamicPPL.nested_setindex_maybe!(
146+
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
147+
)
58148
end
149+
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
150+
vn_parent => varinfo[vn_parent]
151+
end
152+
return varname_pairs
153+
end
154+
155+
"""
156+
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
157+
158+
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
159+
won't catch all cases. We should get rid of this if we can.
160+
"""
161+
# TODO(mhauru) See docstring above.
162+
function _vcat_subsumed_values(vn_string, values, key_strings)
163+
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
164+
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
165+
end
166+
167+
"""
168+
_varname_pairs_without_varname_indexing(
169+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
170+
)
171+
172+
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
173+
from the chain.
174+
175+
This implementation does not assume that `chain` can be indexed using variable names. It is
176+
thus not guaranteed to work in cases where the variable names have complex subsumption
177+
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
178+
"""
179+
function _varname_pairs_without_varname_indexing(
180+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
181+
)
182+
values = chain.value[sample_idx, :, chain_idx]
183+
keys = Base.keys(chain)
184+
keys_strings = map(string, keys)
185+
varname_pairs = [
186+
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
187+
vn in Base.keys(varinfo)
188+
]
189+
return varname_pairs
59190
end
60191

61192
end

ext/DynamicPPLReverseDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ else
99
end
1010

1111
function LogDensityProblemsAD.ADgradient(
12-
ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
13-
) where {Tcompile}
12+
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
13+
)
1414
return LogDensityProblemsAD.ADgradient(
1515
Val(:ReverseDiff),
1616
ℓ;
17-
compile=Val(Tcompile),
17+
compile=Val(ad.compile),
1818
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
1919
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
2020
# `zero(D)` will return 0 when D is Real.

src/DynamicPPL.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ export AbstractVarInfo,
4545
VarInfo,
4646
UntypedVarInfo,
4747
TypedVarInfo,
48+
VectorVarInfo,
4849
SimpleVarInfo,
49-
VarNameVector,
50+
VarNamedVector,
5051
push!!,
5152
empty!!,
5253
subset,
@@ -176,7 +177,7 @@ include("sampler.jl")
176177
include("varname.jl")
177178
include("distribution_wrappers.jl")
178179
include("contexts.jl")
179-
include("varnamevector.jl")
180+
include("varnamedvector.jl")
180181
include("abstract_varinfo.jl")
181182
include("threadsafe.jl")
182183
include("varinfo.jl")

src/abstract_varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
295295
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
296296
297297
julia> # For the sake of brevity, let's just check the type.
298-
md = values_as(vi); md.s isa DynamicPPL.Metadata
298+
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
299299
true
300300
301301
julia> values_as(vi, NamedTuple)
@@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
321321
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
322322
323323
julia> # For the sake of brevity, let's just check the type.
324-
values_as(vi) isa DynamicPPL.Metadata
324+
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
325325
true
326326
327327
julia> values_as(vi, NamedTuple)
@@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
349349
This should generally not be called explicitly, as it's only used in
350350
[`matchingvalue`](@ref) to determine the default type to use in place of
351351
type-parameters passed to the model.
352-
352+
353353
This method is considered legacy, and is likely to be deprecated in the future.
354354
"""
355355
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})

src/context_implementations.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,10 @@ function assume(
240240
if haskey(vi, vn)
241241
# Always overwrite the parameters with new ones for `SampleFromUniform`.
242242
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
243-
unset_flag!(vi, vn, "del")
243+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
244+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
245+
# that's okay.
246+
unset_flag!(vi, vn, "del", true)
244247
r = init(rng, dist, sampler)
245248
f = to_maybe_linked_internal_transform(vi, vn, dist)
246249
BangBang.setindex!!(vi, f(r), vn)
@@ -516,7 +519,10 @@ function get_and_set_val!(
516519
if haskey(vi, vns[1])
517520
# Always overwrite the parameters with new ones for `SampleFromUniform`.
518521
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
519-
unset_flag!(vi, vns[1], "del")
522+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
523+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
524+
# that's okay.
525+
unset_flag!(vi, vns[1], "del", true)
520526
r = init(rng, dist, spl, n)
521527
for i in 1:n
522528
vn = vns[i]
@@ -554,7 +560,10 @@ function get_and_set_val!(
554560
if haskey(vi, vns[1])
555561
# Always overwrite the parameters with new ones for `SampleFromUniform`.
556562
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
557-
unset_flag!(vi, vns[1], "del")
563+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
564+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
565+
# that's okay.
566+
unset_flag!(vi, vns[1], "del", true)
558567
f = (vn, dist) -> init(rng, dist, spl)
559568
r = f.(vns, dists)
560569
for i in eachindex(vns)

0 commit comments

Comments
 (0)