Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export AbstractVarInfo,
setlogprior!!,
setlogjac!!,
setloglikelihood!!,
acclogp,
acclogp!!,
acclogjac!!,
acclogprior!!,
Expand Down
10 changes: 3 additions & 7 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`.
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
"""
function acclogprior!!(vi::AbstractVarInfo, logp)
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior))
end

"""
Expand All @@ -369,9 +369,7 @@ Add `logjac` to the value of the log Jacobian in `vi`.
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
"""
function acclogjac!!(vi::AbstractVarInfo, logjac)
return map_accumulator!!(
acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian)
)
return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian))
end

"""
Expand All @@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`.
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
"""
function accloglikelihood!!(vi::AbstractVarInfo, logp)
return map_accumulator!!(
acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood)
)
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood))
end

"""
Expand Down
69 changes: 69 additions & 0 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
- `combine(acc::T, acc2::T)`

If two accumulators of the same type should be merged in some non-trivial way, other than
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
Comment on lines +33 to +34
Copy link
Member

@penelopeysm penelopeysm Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. What's the difference between combine and merge?

  2. I guess this follows on from our conversation on Slack. but I wonder if there a way to restrict the call on subset / merge to only the accumulators we care about? basically is it possible for us to circumvent the need to call subset on the entire AccumulatorTuple and thus avoid including slightly weird implementations of subset on the logp accumulators? Mainly inspired by how you avoided implementing these for e.g. PLDAcc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine and merge probably do the same thing for most accumulators, but I'm not sure they have to for all accumulators. combine has the restriction that combine(acc, split(acc)) == acc, whereas merge could do anything that is desirable on a call to merge.

I don't know how to specify a subset of accumulators to call merge/subset on in a way that is any simpler than this implementation. Even before this PR, we used to call copy on accs in merge and subset. Now that has just been shifted to be a method of AbstractAccumulator that is the fallback for what to do when a subtype doesn't specify anything else.


If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
do something other than copy the original accumulator, then
`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`

See the documentation for each of these functions for more details.
"""
abstract type AbstractAccumulator end
Expand Down Expand Up @@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function.
"""
convert_eltype(::Type, acc::AbstractAccumulator) = acc

"""
subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})

Return a new accumulator that only contains the information for the `VarName`s in `vns`.

By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
"""
subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)

"""
merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)

Merge two accumulators of the same type. Returns a new accumulator of the same type.

By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
"""
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)

"""
AccumulatorTuple{N,T<:NamedTuple}

Expand Down Expand Up @@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
return AccumulatorTuple(convert(T, accs.nt))
end

"""
subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})

Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
"""
function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
end

"""
_joint_keys(nt1::NamedTuple, nt2::NamedTuple)

A helper function that returns three tuples of keys given two `NamedTuple`s:
The keys only in `nt1`, only in `nt2`, and in both, and in that order.

Implemented as a generated function to enable constant propagation of the result in `merge`.
"""
@generated function _joint_keys(
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
) where {names1,names2}
only_in_nt1 = tuple(setdiff(names1, names2)...)
only_in_nt2 = tuple(setdiff(names2, names1)...)
in_both = tuple(intersect(names1, names2)...)
return :($only_in_nt1, $only_in_nt2, $in_both)
end

"""
merge(at1::AccumulatorTuple, at2::AccumulatorTuple)

Merge two `AccumulatorTuple`s.

For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
accumulators themselves. Other accumulators are copied.
"""
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
accs_in_both = (
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
)
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
end
Comment on lines +203 to +228
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked, and this implementation causes only one allocation of 32 bits, due to a new AccumulatorTuple being created.


"""
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)

Expand Down
Loading
Loading