Skip to content

Commit efd9da3

Browse files
torfjeldedevmotion
andauthored
subset and merge for VarInfo (clean version) (#544)
* added `subset` which can extract a subset of the varinfo * added testing of `subset` for `VarInfo` * formatting * added implementation of `merge` for `VarInfo` and tests for it * more tests * formatting * improved merge_metadata for NamedTuple inputs * added proper handling of the `vals` in `subset` * added docs for `subset` and `merge` * added `subset` and `merge` to documentation * formatting * made merge and subset part of the AbstractVarInfo interface * added implementations `subset` and `merge` for `SimpleVarInfo` * follow standard merge semantics where the right one takes precedence * added proper testing of merge and subset for SimpleVarInfo too * forgotten inclusion in previous commit * Update src/simple_varinfo.jl Co-authored-by: David Widmann <[email protected]> * remove two-argument impl of merge * formatting * forgot to add more formatting * removed 2-arg version of merge for abstract varinfo in favour of 3-arg version * allow inclusion of threadsafe varinfo in setup_varinfos * more tests for thread safe varinfo * bugfixes for link and invlink methods when using thread safe varinfo * attempt at fixing docs * fixed missing test coverage * formatting --------- Co-authored-by: David Widmann <[email protected]>
1 parent 927799f commit efd9da3

File tree

9 files changed

+716
-12
lines changed

9 files changed

+716
-12
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
255255
#### Utils
256256

257257
```@docs
258+
Base.merge(::AbstractVarInfo)
259+
DynamicPPL.subset
258260
DynamicPPL.unflatten
259261
DynamicPPL.tonamedtuple
260262
DynamicPPL.varname_leaves

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export AbstractVarInfo,
4848
SimpleVarInfo,
4949
push!!,
5050
empty!!,
51+
subset,
5152
getlogp,
5253
setlogp!!,
5354
acclogp!!,

src/abstract_varinfo.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
5353
bijector::F
5454
end
5555

56+
"""
57+
merge_transformations(transformation_left, transformation_right)
58+
59+
Merge two transformations.
60+
61+
The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
62+
"""
63+
function merge_transformations(::NoTransformation, ::NoTransformation)
64+
return NoTransformation()
65+
end
66+
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
67+
return DynamicTransformation()
68+
end
69+
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
70+
return StaticTransformation(merge_bijectors(left.bijector, right.bijector))
71+
end
72+
73+
function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
74+
return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs))
75+
end
76+
5677
"""
5778
default_transformation(model::Model[, vi::AbstractVarInfo])
5879
@@ -337,6 +358,146 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
337358
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
338359
end
339360

361+
# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
362+
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
363+
# might result in a `Vector{Any}`.
364+
"""
365+
subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
366+
367+
Subset a `varinfo` to only contain the variables `vns`.
368+
369+
!!! warning
370+
The ordering of the variables in the resulting `varinfo` is _not_
371+
guaranteed to follow the ordering of the variables in `varinfo`.
372+
Hence care must be taken, in particular when used in conjunction with
373+
other methods which uses the vector-representation of the `varinfo`,
374+
e.g. `getindex(varinfo, sampler)`.
375+
376+
# Examples
377+
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
378+
julia> @model function demo()
379+
s ~ InverseGamma(2, 3)
380+
m ~ Normal(0, sqrt(s))
381+
x = Vector{Float64}(undef, 2)
382+
x[1] ~ Normal(m, sqrt(s))
383+
x[2] ~ Normal(m, sqrt(s))
384+
end
385+
demo (generic function with 2 methods)
386+
387+
julia> model = demo();
388+
389+
julia> varinfo = VarInfo(model);
390+
391+
julia> keys(varinfo)
392+
4-element Vector{VarName}:
393+
s
394+
m
395+
x[1]
396+
x[2]
397+
398+
julia> for (i, vn) in enumerate(keys(varinfo))
399+
varinfo[vn] = i
400+
end
401+
402+
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
403+
4-element Vector{Float64}:
404+
1.0
405+
2.0
406+
3.0
407+
4.0
408+
409+
julia> # Extract one with only `m`.
410+
varinfo_subset1 = subset(varinfo, [@varname(m),]);
411+
412+
413+
julia> keys(varinfo_subset1)
414+
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
415+
m
416+
417+
julia> varinfo_subset1[@varname(m)]
418+
2.0
419+
420+
julia> # Extract one with both `s` and `x[2]`.
421+
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);
422+
423+
julia> keys(varinfo_subset2)
424+
2-element Vector{VarName}:
425+
s
426+
x[2]
427+
428+
julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
429+
2-element Vector{Float64}:
430+
1.0
431+
4.0
432+
```
433+
434+
`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)
435+
436+
```jldoctest varinfo-subset
437+
julia> # Merge the two.
438+
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);
439+
440+
julia> keys(varinfo_subset_merged)
441+
3-element Vector{VarName}:
442+
m
443+
s
444+
x[2]
445+
446+
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
447+
3-element Vector{Float64}:
448+
1.0
449+
2.0
450+
4.0
451+
452+
julia> # Merge the two with the original.
453+
varinfo_merged = merge(varinfo, varinfo_subset_merged);
454+
455+
julia> keys(varinfo_merged)
456+
4-element Vector{VarName}:
457+
s
458+
m
459+
x[1]
460+
x[2]
461+
462+
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
463+
4-element Vector{Float64}:
464+
1.0
465+
2.0
466+
3.0
467+
4.0
468+
```
469+
470+
# Notes
471+
472+
## Type-stability
473+
474+
!!! warning
475+
This function is only type-stable when `vns` contains only varnames
476+
with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will
477+
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
478+
"""
479+
function subset end
480+
481+
"""
482+
merge(varinfo, other_varinfos...)
483+
484+
Merge varinfos into one, giving precedence to the right-most varinfo when sensible.
485+
486+
This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).
487+
488+
See docstring of [`subset(varinfo, vns)`](@ref) for examples.
489+
"""
490+
Base.merge(varinfo::AbstractVarInfo) = varinfo
491+
# Define 3-argument version so 2-argument version will error if not implemented.
492+
function Base.merge(
493+
varinfo1::AbstractVarInfo,
494+
varinfo2::AbstractVarInfo,
495+
varinfo3::AbstractVarInfo,
496+
varinfo_others::AbstractVarInfo...,
497+
)
498+
return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...)
499+
end
500+
340501
# Transformations
341502
"""
342503
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])

src/simple_varinfo.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,51 @@ function Base.eltype(
419419
return V
420420
end
421421

422+
# `subset`
423+
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
424+
return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
425+
end
426+
427+
function _subset(x::AbstractDict, vns)
428+
# NOTE: This requires `vns` to be explicitly present in `x`.
429+
if any(!Base.Fix1(haskey, x), vns)
430+
throw(
431+
ArgumentError(
432+
"Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
433+
"For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
434+
"`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
435+
),
436+
)
437+
end
438+
C = ConstructionBase.constructorof(typeof(x))
439+
return C(vn => x[vn] for vn in vns)
440+
end
441+
442+
function _subset(x::NamedTuple, vns)
443+
# NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
444+
if any(Base.Fix1(!==, Setfield.IdentityLens()) getlens, vns)
445+
throw(
446+
ArgumentError(
447+
"Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
448+
"For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
449+
),
450+
)
451+
end
452+
453+
syms = map(getsym, vns)
454+
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))
455+
end
456+
457+
# `merge`
458+
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
459+
values = merge(varinfo_left.values, varinfo_right.values)
460+
logp = getlogp(varinfo_right)
461+
transformation = merge_transformations(
462+
varinfo_left.transformation, varinfo_right.transformation
463+
)
464+
return SimpleVarInfo(values, logp, transformation)
465+
end
466+
422467
# Context implementations
423468
# NOTE: Evaluations, i.e. those without `rng` are shared with other
424469
# implementations of `AbstractVarInfo`.

src/test_utils.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal
3737
end
3838

3939
"""
40-
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
40+
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
4141
4242
Return a tuple of instances for different implementations of `AbstractVarInfo` with
4343
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
44+
45+
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
46+
of the varinfo instances.
4447
"""
45-
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
48+
function setup_varinfos(
49+
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
50+
)
4651
# VarInfo
4752
vi_untyped = VarInfo()
4853
model(vi_untyped)
@@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
5661
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))
5762

5863
lp = getlogp(vi_typed)
59-
return map((
64+
varinfos = map((
6065
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
6166
)) do vi
6267
# Set them all to the same values.
6368
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
6469
end
70+
71+
if include_threadsafe
72+
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo deepcopy, varinfos)...)
73+
end
74+
75+
return varinfos
6576
end
6677

6778
"""

src/threadsafe.jl

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,25 +84,56 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl
8484
function link!!(
8585
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
8686
)
87-
return link!!(t, vi.varinfo, spl, model)
87+
return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
8888
end
8989

9090
function invlink!!(
9191
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
9292
)
93-
return invlink!!(t, vi.varinfo, spl, model)
93+
return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
9494
end
9595

9696
function link(
9797
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
9898
)
99-
return link(t, vi.varinfo, spl, model)
99+
return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model)
100100
end
101101

102102
function invlink(
103103
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
104104
)
105-
return invlink(t, vi.varinfo, spl, model)
105+
return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
106+
end
107+
108+
# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
109+
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
110+
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
111+
# to define `getlogp(vi)`.
112+
function link!!(
113+
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
114+
)
115+
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
116+
end
117+
118+
function invlink!!(
119+
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
120+
)
121+
return settrans!!(
122+
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
123+
NoTransformation(),
124+
)
125+
end
126+
127+
function link(
128+
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
129+
)
130+
return link!!(t, deepcopy(vi), spl, model)
131+
end
132+
133+
function invlink(
134+
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
135+
)
136+
return invlink!!(t, deepcopy(vi), spl, model)
106137
end
107138

108139
function maybe_invlink_before_eval!!(
@@ -192,3 +223,20 @@ istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
192223
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
193224

194225
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
226+
227+
function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
228+
return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x)
229+
end
230+
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
231+
return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x)
232+
end
233+
234+
function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
235+
return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns)
236+
end
237+
238+
function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo)
239+
return Setfield.@set varinfo_left.varinfo = merge(
240+
varinfo_left.varinfo, varinfo_right.varinfo
241+
)
242+
end

0 commit comments

Comments
 (0)