Skip to content

Commit 7ad9450

Browse files
committed
Move default accumulators to default_accumulators.jl
1 parent c4ee4ec commit 7ad9450

File tree

3 files changed

+145
-149
lines changed

3 files changed

+145
-149
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ include("distribution_wrappers.jl")
179179
include("contexts.jl")
180180
include("varnamedvector.jl")
181181
include("accumulators.jl")
182+
include("default_accumulators.jl")
182183
include("abstract_varinfo.jl")
183184
include("threadsafe.jl")
184185
include("varinfo.jl")

src/accumulators.jl

Lines changed: 0 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ used by various AD backends, should implement a method for this function.
9999
"""
100100
convert_eltype(::Type, acc::AbstractAccumulator) = acc
101101

102-
# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE
103-
104102
"""
105103
AccumulatorTuple{N,T<:NamedTuple}
106104
@@ -189,150 +187,3 @@ function map_accumulator(
189187
new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,)))
190188
return AccumulatorTuple(new_nt)
191189
end
192-
193-
# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS
194-
195-
"""
196-
LogPriorAccumulator{T<:Real} <: AbstractAccumulator
197-
198-
An accumulator that tracks the cumulative log prior during model execution.
199-
200-
# Fields
201-
$(TYPEDFIELDS)
202-
"""
203-
struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator
204-
"the scalar log prior value"
205-
logp::T
206-
end
207-
208-
"""
209-
LogPriorAccumulator{T}()
210-
211-
Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
212-
"""
213-
LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T))
214-
LogPriorAccumulator() = LogPriorAccumulator{LogProbType}()
215-
216-
"""
217-
LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
218-
219-
An accumulator that tracks the cumulative log likelihood during model execution.
220-
221-
# Fields
222-
$(TYPEDFIELDS)
223-
"""
224-
struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
225-
"the scalar log likelihood value"
226-
logp::T
227-
end
228-
229-
"""
230-
LogLikelihoodAccumulator{T}()
231-
232-
Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
233-
"""
234-
LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T))
235-
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
236-
237-
"""
238-
NumProduceAccumulator{T} <: AbstractAccumulator
239-
240-
An accumulator that tracks the number of observations during model execution.
241-
242-
# Fields
243-
$(TYPEDFIELDS)
244-
"""
245-
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
246-
"the number of observations"
247-
num::T
248-
end
249-
250-
"""
251-
NumProduceAccumulator{T<:Integer}()
252-
253-
Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
254-
"""
255-
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
256-
NumProduceAccumulator() = NumProduceAccumulator{Int}()
257-
258-
function Base.show(io::IO, acc::LogPriorAccumulator)
259-
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
260-
end
261-
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
262-
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
263-
end
264-
function Base.show(io::IO, acc::NumProduceAccumulator)
265-
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
266-
end
267-
268-
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
269-
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
270-
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
271-
272-
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
273-
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
274-
split(acc::NumProduceAccumulator) = acc
275-
276-
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
277-
return LogPriorAccumulator(acc.logp + acc2.logp)
278-
end
279-
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
280-
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
281-
end
282-
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
283-
return NumProduceAccumulator(max(acc.num, acc2.num))
284-
end
285-
286-
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
287-
return LogPriorAccumulator(acc1.logp + acc2.logp)
288-
end
289-
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
290-
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
291-
end
292-
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
293-
294-
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
295-
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
296-
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
297-
298-
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
299-
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
300-
end
301-
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
302-
303-
accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
304-
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
305-
# Note that it's important to use the loglikelihood function here, not logpdf, because
306-
# they handle vectors differently:
307-
# https://github.com/JuliaStats/Distributions.jl/issues/1972
308-
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
309-
end
310-
311-
accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
312-
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
313-
314-
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
315-
return LogPriorAccumulator(convert(T, acc.logp))
316-
end
317-
function Base.convert(
318-
::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator
319-
) where {T}
320-
return LogLikelihoodAccumulator(convert(T, acc.logp))
321-
end
322-
function Base.convert(
323-
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
324-
) where {T}
325-
return NumProduceAccumulator(convert(T, acc.num))
326-
end
327-
328-
# TODO(mhauru)
329-
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
330-
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
331-
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
332-
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
333-
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
334-
return LogPriorAccumulator(convert(T, acc.logp))
335-
end
336-
function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
337-
return LogLikelihoodAccumulator(convert(T, acc.logp))
338-
end

src/default_accumulators.jl

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
LogPriorAccumulator{T<:Real} <: AbstractAccumulator
3+
4+
An accumulator that tracks the cumulative log prior during model execution.
5+
6+
# Fields
7+
$(TYPEDFIELDS)
8+
"""
9+
struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator
10+
"the scalar log prior value"
11+
logp::T
12+
end
13+
14+
"""
15+
LogPriorAccumulator{T}()
16+
17+
Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
18+
"""
19+
LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T))
20+
LogPriorAccumulator() = LogPriorAccumulator{LogProbType}()
21+
22+
"""
23+
LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
24+
25+
An accumulator that tracks the cumulative log likelihood during model execution.
26+
27+
# Fields
28+
$(TYPEDFIELDS)
29+
"""
30+
struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
31+
"the scalar log likelihood value"
32+
logp::T
33+
end
34+
35+
"""
36+
LogLikelihoodAccumulator{T}()
37+
38+
Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
39+
"""
40+
LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T))
41+
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
42+
43+
"""
44+
NumProduceAccumulator{T} <: AbstractAccumulator
45+
46+
An accumulator that tracks the number of observations during model execution.
47+
48+
# Fields
49+
$(TYPEDFIELDS)
50+
"""
51+
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
52+
"the number of observations"
53+
num::T
54+
end
55+
56+
"""
57+
NumProduceAccumulator{T<:Integer}()
58+
59+
Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
60+
"""
61+
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
62+
NumProduceAccumulator() = NumProduceAccumulator{Int}()
63+
64+
function Base.show(io::IO, acc::LogPriorAccumulator)
65+
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
66+
end
67+
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
68+
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
69+
end
70+
function Base.show(io::IO, acc::NumProduceAccumulator)
71+
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
72+
end
73+
74+
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
75+
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
76+
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
77+
78+
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
79+
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
80+
split(acc::NumProduceAccumulator) = acc
81+
82+
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
83+
return LogPriorAccumulator(acc.logp + acc2.logp)
84+
end
85+
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
86+
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
87+
end
88+
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
89+
return NumProduceAccumulator(max(acc.num, acc2.num))
90+
end
91+
92+
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
93+
return LogPriorAccumulator(acc1.logp + acc2.logp)
94+
end
95+
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
96+
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
97+
end
98+
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
99+
100+
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
101+
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
102+
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
103+
104+
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
105+
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
106+
end
107+
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
108+
109+
accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
110+
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
111+
# Note that it's important to use the loglikelihood function here, not logpdf, because
112+
# they handle vectors differently:
113+
# https://github.com/JuliaStats/Distributions.jl/issues/1972
114+
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
115+
end
116+
117+
accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
118+
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
119+
120+
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
121+
return LogPriorAccumulator(convert(T, acc.logp))
122+
end
123+
function Base.convert(
124+
::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator
125+
) where {T}
126+
return LogLikelihoodAccumulator(convert(T, acc.logp))
127+
end
128+
function Base.convert(
129+
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
130+
) where {T}
131+
return NumProduceAccumulator(convert(T, acc.num))
132+
end
133+
134+
# TODO(mhauru)
135+
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
136+
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137+
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
138+
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139+
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
140+
return LogPriorAccumulator(convert(T, acc.logp))
141+
end
142+
function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
143+
return LogLikelihoodAccumulator(convert(T, acc.logp))
144+
end

0 commit comments

Comments
 (0)