Skip to content

Commit bc51d62

Browse files
committed
Introduce merge and subset for accs
1 parent f983da5 commit bc51d62

File tree

5 files changed

+118
-4
lines changed

5 files changed

+118
-4
lines changed

src/accumulators.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement:
3030
- `split(acc::T)`
3131
- `combine(acc::T, acc2::T)`
3232
33+
If two accumulators of the same type should be merged in some non-trivial way, other than
34+
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
35+
36+
If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
37+
do something other than copy the original accumulator, then
38+
`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
39+
3340
See the documentation for each of these functions for more details.
3441
"""
3542
abstract type AbstractAccumulator end
@@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function.
113120
"""
114121
convert_eltype(::Type, acc::AbstractAccumulator) = acc
115122

123+
"""
124+
subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
125+
126+
Return a new accumulator that only contains the information for the `VarName`s in `vns`.
127+
128+
By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
129+
"""
130+
subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)
131+
132+
"""
133+
merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
134+
135+
Merge two accumulators of the same type. Returns a new accumulator of the same type.
136+
137+
By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
138+
"""
139+
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)
140+
116141
"""
117142
AccumulatorTuple{N,T<:NamedTuple}
118143
@@ -158,6 +183,52 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
158183
return AccumulatorTuple(convert(T, accs.nt))
159184
end
160185

186+
"""
187+
subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
188+
189+
Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
190+
"""
191+
function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
192+
return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
193+
end
194+
195+
"""
196+
_joint_keys(nt1::NamedTuple, nt2::NamedTuple)
197+
198+
A helper function that returns three tuples of keys given two `NamedTuple`s:
199+
The keys only in `nt1`, only in `nt2`, and in both, and in that order.
200+
201+
Implemented as a generated function to enabled constant propagation of the result in `merge`.
202+
"""
203+
@generated function _joint_keys(
204+
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
205+
) where {names1,names2}
206+
set_names1 = Set(names1)
207+
set_names2 = Set(names2)
208+
only_in_nt1 = tuple(setdiff(set_names1, set_names2)...)
209+
only_in_nt2 = tuple(setdiff(set_names2, set_names1)...)
210+
in_both = tuple(intersect(set_names1, set_names2)...)
211+
return :($only_in_nt1, $only_in_nt2, $in_both)
212+
end
213+
214+
"""
215+
merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
216+
217+
Merge two `AccumulatorTuple`s.
218+
219+
For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
220+
accumulators themselves. Other accumulators are copied.
221+
"""
222+
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
223+
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
224+
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
225+
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
226+
accs_in_both = (
227+
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
228+
)
229+
return AccumulatorTuple(accs_in_at1..., accs_in_at2..., accs_in_both...)
230+
end
231+
161232
"""
162233
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
163234

src/default_accumulators.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,19 @@ function default_accumulators(
266266
VariableOrderAccumulator{IntT}(),
267267
)
268268
end
269+
270+
function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName})
271+
order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order)
272+
return VariableOrderAccumulator(acc.num_produce, order)
273+
end
274+
275+
"""
276+
merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
277+
278+
Merge two `VariableOrderAccumulator` instances.
279+
280+
The `num_produce` field of the return value is the `num_produce` of `acc2`.
281+
"""
282+
function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
283+
return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order))
284+
end

src/simple_varinfo.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,9 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
417417

418418
# `subset`
419419
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
420-
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
420+
return SimpleVarInfo(
421+
_subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation
422+
)
421423
end
422424

423425
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
@@ -454,7 +456,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns)
454456
# `merge`
455457
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
456458
values = merge(varinfo_left.values, varinfo_right.values)
457-
accs = copy(getaccs(varinfo_right))
459+
accs = merge(getaccs(varinfo_left), getaccs(varinfo_right))
458460
transformation = merge_transformations(
459461
varinfo_left.transformation, varinfo_right.transformation
460462
)

src/varinfo.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ end
447447

448448
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
449449
metadata = subset(varinfo.metadata, vns)
450-
return VarInfo(metadata, copy(varinfo.accs))
450+
return VarInfo(metadata, subset(getaccs(varinfo), vns))
451451
end
452452

453453
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
@@ -528,7 +528,8 @@ end
528528

529529
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
530530
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
531-
return VarInfo(metadata, copy(varinfo_right.accs))
531+
accs = merge(getaccs(varinfo_left), getaccs(varinfo_right))
532+
return VarInfo(metadata, accs)
532533
end
533534

534535
function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)

test/accumulators.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,30 @@ using DynamicPPL:
108108
@test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) ==
109109
VariableOrderAccumulator(2)
110110
end
111+
112+
@testset "merge and subset" begin
113+
@test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) ==
114+
LogPriorAccumulator(3.0)
115+
@test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) ==
116+
LogJacobianAccumulator(3.0)
117+
@test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) ==
118+
LogLikelihoodAccumulator(3.0)
119+
120+
@test merge(
121+
VariableOrderAccumulator(1, Dict{VarName,Int}()),
122+
VariableOrderAccumulator(2, Dict{VarName,Int}()),
123+
) == VariableOrderAccumulator(2, Dict{VarName,Int}())
124+
@test merge(
125+
VariableOrderAccumulator(
126+
2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2))
127+
),
128+
VariableOrderAccumulator(
129+
1, Dict{VarName,Int}((@varname(a) => 2, @varname(c) => 3))
130+
),
131+
) == VariableOrderAccumulator(
132+
1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3))
133+
)
134+
end
111135
end
112136

113137
@testset "accumulator tuples" begin

0 commit comments

Comments
 (0)