Skip to content

Commit 5d9e934

Browse files
mhaurupenelopeysm
andauthored
Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator (#999)
* logjac accumulator * Fix tests * Fix a whole bunch of stuff * Fix final tests * Fix docs * Fix docs/doctests * Fix maths in LogJacobianAccumulator docstring * Twiddle with a comment * Add changelog * Simplify accs with LogProbAccumulator * Replace + with accumulate for LogProbAccs * Introduce merge and subset for accs * Improve acc tests * Fix docstring typo. Co-authored-by: Penelope Yong <[email protected]> * Fix merge --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent c6c0cbc commit 5d9e934

File tree

8 files changed

+332
-156
lines changed

8 files changed

+332
-156
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export AbstractVarInfo,
6666
setlogprior!!,
6767
setlogjac!!,
6868
setloglikelihood!!,
69+
acclogp,
6970
acclogp!!,
7071
acclogjac!!,
7172
acclogprior!!,

src/abstract_varinfo.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`.
358358
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
359359
"""
360360
function acclogprior!!(vi::AbstractVarInfo, logp)
361-
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
361+
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior))
362362
end
363363

364364
"""
@@ -369,9 +369,7 @@ Add `logjac` to the value of the log Jacobian in `vi`.
369369
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
370370
"""
371371
function acclogjac!!(vi::AbstractVarInfo, logjac)
372-
return map_accumulator!!(
373-
acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian)
374-
)
372+
return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian))
375373
end
376374

377375
"""
@@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`.
382380
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
383381
"""
384382
function accloglikelihood!!(vi::AbstractVarInfo, logp)
385-
return map_accumulator!!(
386-
acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood)
387-
)
383+
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood))
388384
end
389385

390386
"""

src/accumulators.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement:
3030
- `split(acc::T)`
3131
- `combine(acc::T, acc2::T)`
3232
33+
If two accumulators of the same type should be merged in some non-trivial way, other than
34+
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
35+
36+
If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
37+
do something other than copy the original accumulator, then
38+
`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
39+
3340
See the documentation for each of these functions for more details.
3441
"""
3542
abstract type AbstractAccumulator end
@@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function.
113120
"""
114121
convert_eltype(::Type, acc::AbstractAccumulator) = acc
115122

123+
"""
124+
subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
125+
126+
Return a new accumulator that only contains the information for the `VarName`s in `vns`.
127+
128+
By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
129+
"""
130+
subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)
131+
132+
"""
133+
merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
134+
135+
Merge two accumulators of the same type. Returns a new accumulator of the same type.
136+
137+
By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
138+
"""
139+
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)
140+
116141
"""
117142
AccumulatorTuple{N,T<:NamedTuple}
118143
@@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
158183
return AccumulatorTuple(convert(T, accs.nt))
159184
end
160185

186+
"""
187+
subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
188+
189+
Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
190+
"""
191+
function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
192+
return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
193+
end
194+
195+
"""
196+
_joint_keys(nt1::NamedTuple, nt2::NamedTuple)
197+
198+
A helper function that returns three tuples of keys given two `NamedTuple`s:
199+
The keys only in `nt1`, only in `nt2`, and in both, and in that order.
200+
201+
Implemented as a generated function to enable constant propagation of the result in `merge`.
202+
"""
203+
@generated function _joint_keys(
204+
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
205+
) where {names1,names2}
206+
only_in_nt1 = tuple(setdiff(names1, names2)...)
207+
only_in_nt2 = tuple(setdiff(names2, names1)...)
208+
in_both = tuple(intersect(names1, names2)...)
209+
return :($only_in_nt1, $only_in_nt2, $in_both)
210+
end
211+
212+
"""
213+
merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
214+
215+
Merge two `AccumulatorTuple`s.
216+
217+
For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
218+
accumulators themselves. Other accumulators are copied.
219+
"""
220+
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
221+
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
222+
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
223+
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
224+
accs_in_both = (
225+
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
226+
)
227+
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
228+
end
229+
161230
"""
162231
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
163232

0 commit comments

Comments
 (0)