Skip to content

Commit 221e797

Browse files
committed
Improve accumulator docs
1 parent d52feec commit 221e797

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/accumulators.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ See the documentation for each of these functions for more details.
2222
"""
2323
abstract type AbstractAccumulator end
2424

25-
# TODO(mhauru) Add to the above docstring stuff about resets.
26-
2725
"""
2826
accumulator_name(acc::AbstractAccumulator)
2927
@@ -88,17 +86,18 @@ See also: [`split`](@ref)
8886
"""
8987
function combine end
9088

89+
# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in
90+
# src/varinfo.jl.
9191
"""
92-
acc!!(acc::AbstractAccumulator, args...)
92+
convert_eltype(acc::AbstractAccumulator, ::Type{T})
9393
94-
Update `acc` with the values in `args`. Returns the updated `acc`.
94+
Convert `acc` to use element type `T`.
9595
96-
What this means depends greatly on the type of `acc`. For example, for `LogPrior` `args`
97-
would be just `logp`. The utility of this function is that one can call
98-
`acc!!(varinfo::AbstractVarinfo, Val(accname), args...)`, and this call will be propagated
99-
to a call on the particular accumulator.
96+
What "element type" means depends on the type of `acc`. By default this function does
97+
nothing. Accumulator types that need to hold differentiable values, such as dual numbers
98+
used by various AD backends, should implement a method for this function.
10099
"""
101-
function acc!! end
100+
convert_eltype(acc::AbstractAccumulator, ::Type) = acc
102101

103102
# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE
104103

@@ -314,12 +313,12 @@ function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T}
314313
return NumProduce(convert(T, acc.num))
315314
end
316315

316+
# TODO(mhauru)
317+
# We ignore the convert_eltype calls for NumProduce, by letting them fallback on
318+
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
319+
# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
320+
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
317321
convert_eltype(acc::LogPrior, ::Type{T}) where {T} = LogPrior(convert(T, acc.logp))
318322
function convert_eltype(acc::LogLikelihood, ::Type{T}) where {T}
319323
return LogLikelihood(convert(T, acc.logp))
320324
end
321-
# TODO(mhauru)
322-
# We ignore the convert_eltype calls for NumProduce. This is because they are only used to
323-
# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
324-
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
325-
convert_eltype(acc::NumProduce, ::Type) = NumProduce(acc.num)

0 commit comments

Comments
 (0)