@@ -22,8 +22,6 @@ See the documentation for each of these functions for more details.
2222"""
2323abstract type AbstractAccumulator end
2424
25- # TODO (mhauru) Add to the above docstring stuff about resets.
26-
2725"""
2826 accumulator_name(acc::AbstractAccumulator)
2927
@@ -88,17 +86,18 @@ See also: [`split`](@ref)
8886"""
8987function combine end
9088
89+ # TODO (mhauru) The existence of this function makes me sad. See comment in unflatten in
90+ # src/varinfo.jl.
9191"""
92- acc!! (acc::AbstractAccumulator, args... )
92+ convert_eltype (acc::AbstractAccumulator, ::Type{T} )
9393
94- Update `acc` with the values in `args`. Returns the updated `acc `.
94+ Convert `acc` to use element type `T `.
9595
96- What this means depends greatly on the type of `acc`. For example, for `LogPrior` `args`
97- would be just `logp`. The utility of this function is that one can call
98- `acc!!(varinfo::AbstractVarinfo, Val(accname), args...)`, and this call will be propagated
99- to a call on the particular accumulator.
96+ What "element type" means depends on the type of `acc`. By default this function does
97+ nothing. Accumulator types that need to hold differentiable values, such as dual numbers
98+ used by various AD backends, should implement a method for this function.
10099"""
101- function acc!! end
100+ convert_eltype ( acc:: AbstractAccumulator , :: Type ) = acc
102101
103102# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE
104103
@@ -314,12 +313,12 @@ function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T}
314313 return NumProduce (convert (T, acc. num))
315314end
316315
316+ # TODO (mhauru)
317+ # We ignore the convert_eltype calls for NumProduce, by letting them fallback on
318+ # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
319+ # deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
320+ # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
317321convert_eltype (acc:: LogPrior , :: Type{T} ) where {T} = LogPrior (convert (T, acc. logp))
318322function convert_eltype (acc:: LogLikelihood , :: Type{T} ) where {T}
319323 return LogLikelihood (convert (T, acc. logp))
320324end
321- # TODO (mhauru)
322- # We ignore the convert_eltype calls for NumProduce. This is because they are only used to
323- # deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
324- # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
325- convert_eltype (acc:: NumProduce , :: Type ) = NumProduce (acc. num)
0 commit comments