Skip to content

Commit e9bf50b

Browse files
committed
Simplify accs with LogProbAccumulator
1 parent 10de51f commit e9bf50b

File tree

2 files changed

+116
-130
lines changed

2 files changed

+116
-130
lines changed

src/default_accumulators.jl

Lines changed: 109 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,84 @@
11
"""
2-
LogPriorAccumulator{T<:Real} <: AbstractAccumulator
2+
LogProbAccumulator{T} <: AbstractAccumulator
3+
4+
An abstract type for accumulators that hold a single scalar log probability value.
5+
6+
Every subtype of `LogProbAccumulator` must implement
7+
* A method for `logp` that returns the scalar log probability value that defines it.
8+
* A single-argument constructor that takes a `logp` value.
9+
* `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any
10+
other accumulator.
11+
12+
`LogProbAccumulator` provides implementations for other common functions, like convenience
13+
constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`.
14+
15+
This type has no great conceptual significance, it just reduces code duplication between
16+
types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator.
17+
"""
18+
abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end
19+
20+
# The first of the below methods sets AccType{T}() = AccType(zero(T)) for any
21+
# AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T
22+
# when calling AccType().
23+
"""
24+
LogProbAccumulator{T}()
25+
26+
Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero.
27+
"""
28+
(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T))
29+
(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}()
30+
31+
Base.copy(acc::LogProbAccumulator) = acc
32+
33+
function Base.show(io::IO, acc::LogProbAccumulator)
34+
return print(io, "$(repr(accumulator_name(acc)))($(repr(logp(acc)))))")
35+
end
36+
37+
# Note that == and isequal are different, and equality under the latter should imply
38+
# equality of hashes. Both of the below implementations are also different from the default
39+
# implementation for structs.
40+
function Base.:(==)(acc1::LogProbAccumulator, acc2::LogProbAccumulator)
41+
return accumulator_name(acc1) === accumulator_name(acc2) && logp(acc1) == logp(acc2)
42+
end
43+
44+
function Base.isequal(acc1::LogProbAccumulator, acc2::LogProbAccumulator)
45+
return basetypeof(acc1) === basetypeof(acc2) && isequal(logp(acc1), logp(acc2))
46+
end
47+
48+
Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h)
49+
50+
split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T))
51+
52+
function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator)
53+
if basetypeof(acc) !== basetypeof(acc2)
54+
msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))"
55+
throw(ArgumentError(msg))
56+
end
57+
return basetypeof(acc)(logp(acc) + logp(acc2))
58+
end
59+
60+
function Base.:+(acc1::LogProbAccumulator, acc2::LogProbAccumulator)
61+
if basetypeof(acc1) !== basetypeof(acc2)
62+
msg = "Cannot add accumulators of different types: $(basetypeof(acc1)) and $(basetypeof(acc2))"
63+
throw(ArgumentError(msg))
64+
end
65+
return basetypeof(acc1)(logp(acc1) + logp(acc2))
66+
end
67+
68+
Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc)))
69+
70+
function Base.convert(
71+
::Type{AccType}, acc::LogProbAccumulator
72+
) where {T,AccType<:LogProbAccumulator{T}}
73+
return AccType(convert(T, logp(acc)))
74+
end
75+
76+
function convert_eltype(::Type{T}, acc::LogProbAccumulator) where {T}
77+
return basetypeof(acc)(convert(T, logp(acc)))
78+
end
79+
80+
"""
81+
LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T}
382
483
An accumulator that tracks the cumulative log prior during model execution.
584
@@ -10,21 +89,22 @@ linked or not.
1089
# Fields
1190
$(TYPEDFIELDS)
1291
"""
13-
struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator
92+
struct LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T}
1493
"the scalar log prior value"
1594
logp::T
1695
end
1796

18-
"""
19-
LogPriorAccumulator{T}()
97+
logp(acc::LogPriorAccumulator) = acc.logp
2098

21-
Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
22-
"""
23-
LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T))
24-
LogPriorAccumulator() = LogPriorAccumulator{LogProbType}()
99+
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
100+
101+
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
102+
return acc + LogPriorAccumulator(logpdf(right, val))
103+
end
104+
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
25105

26106
"""
27-
LogJacobianAccumulator{T<:Real} <: AbstractAccumulator
107+
LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T}
28108
29109
An accumulator that tracks the cumulative log Jacobian (technically,
30110
log(abs(det(J)))) during model execution. Specifically, J refers to the
@@ -53,39 +133,44 @@ distribution to unconstrained space.
53133
# Fields
54134
$(TYPEDFIELDS)
55135
"""
56-
struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator
136+
struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T}
57137
"the logabsdet of the link transform Jacobian"
58138
logJ::T
59139
end
60140

61-
"""
62-
LogJacobianAccumulator{T}()
141+
logp(acc::LogJacobianAccumulator) = acc.logJ
63142

64-
Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero.
65-
"""
66-
LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T))
67-
LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}()
143+
accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian
144+
145+
function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right)
146+
return acc + LogJacobianAccumulator(logjac)
147+
end
148+
accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc
68149

69150
"""
70-
LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
151+
LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T}
71152
72153
An accumulator that tracks the cumulative log likelihood during model execution.
73154
74155
# Fields
75156
$(TYPEDFIELDS)
76157
"""
77-
struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
158+
struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T}
78159
"the scalar log likelihood value"
79160
logp::T
80161
end
81162

82-
"""
83-
LogLikelihoodAccumulator{T}()
163+
logp(acc::LogLikelihoodAccumulator) = acc.logp
84164

85-
Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
86-
"""
87-
LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T))
88-
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
165+
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
166+
167+
accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
168+
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
169+
# Note that it's important to use the loglikelihood function here, not logpdf, because
170+
# they handle vectors differently:
171+
# https://github.com/JuliaStats/Distributions.jl/issues/1972
172+
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
173+
end
89174

90175
"""
91176
VariableOrderAccumulator{T} <: AbstractAccumulator
@@ -117,85 +202,32 @@ VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
117202
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
118203
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
119204

120-
Base.copy(acc::LogPriorAccumulator) = acc
121-
Base.copy(acc::LogJacobianAccumulator) = acc
122-
Base.copy(acc::LogLikelihoodAccumulator) = acc
123205
function Base.copy(acc::VariableOrderAccumulator)
124206
return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
125207
end
126208

127-
function Base.show(io::IO, acc::LogPriorAccumulator)
128-
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
129-
end
130-
function Base.show(io::IO, acc::LogJacobianAccumulator)
131-
return print(io, "LogJacobianAccumulator($(repr(acc.logJ)))")
132-
end
133-
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
134-
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
135-
end
136209
function Base.show(io::IO, acc::VariableOrderAccumulator)
137210
return print(
138211
io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
139212
)
140213
end
141214

142-
# Note that == and isequal are different, and equality under the latter should imply
143-
# equality of hashes. Both of the below implementations are also different from the default
144-
# implementation for structs.
145-
Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp
146-
function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator)
147-
return acc1.logJ == acc2.logJ
148-
end
149-
function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
150-
return acc1.logp == acc2.logp
151-
end
152215
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
153216
return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
154217
end
155218

156-
function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
157-
return isequal(acc1.logp, acc2.logp)
158-
end
159-
function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator)
160-
return isequal(acc1.logJ, acc2.logJ)
161-
end
162-
function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
163-
return isequal(acc1.logp, acc2.logp)
164-
end
165219
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
166220
return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
167221
end
168222

169-
Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h)
170-
function Base.hash(acc::LogJacobianAccumulator, h::UInt)
171-
return hash((LogJacobianAccumulator, acc.logJ), h)
172-
end
173-
function Base.hash(acc::LogLikelihoodAccumulator, h::UInt)
174-
return hash((LogLikelihoodAccumulator, acc.logp), h)
175-
end
176223
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
177224
return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
178225
end
179226

180-
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
181-
accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian
182-
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
183227
accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder
184228

185-
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
186-
split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T))
187-
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
188229
split(acc::VariableOrderAccumulator) = copy(acc)
189230

190-
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
191-
return LogPriorAccumulator(acc.logp + acc2.logp)
192-
end
193-
function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator)
194-
return LogJacobianAccumulator(acc.logJ + acc2.logJ)
195-
end
196-
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
197-
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
198-
end
199231
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
200232
# Note that assumptions are not allowed in parallelised blocks, and thus the
201233
# dictionaries should be identical.
@@ -204,60 +236,16 @@ function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
204236
)
205237
end
206238

207-
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
208-
return LogPriorAccumulator(acc1.logp + acc2.logp)
209-
end
210-
function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator)
211-
return LogJacobianAccumulator(acc1.logJ + acc2.logJ)
212-
end
213-
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
214-
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
215-
end
216239
function increment(acc::VariableOrderAccumulator)
217240
return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
218241
end
219242

220-
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
221-
Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logJ))
222-
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
223-
224-
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
225-
return acc + LogPriorAccumulator(logpdf(right, val))
226-
end
227-
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
228-
229-
function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right)
230-
return acc + LogJacobianAccumulator(logjac)
231-
end
232-
accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc
233-
234-
accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
235-
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
236-
# Note that it's important to use the loglikelihood function here, not logpdf, because
237-
# they handle vectors differently:
238-
# https://github.com/JuliaStats/Distributions.jl/issues/1972
239-
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
240-
end
241-
242243
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
243244
acc.order[vn] = acc.num_produce
244245
return acc
245246
end
246247
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)
247248

248-
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
249-
return LogPriorAccumulator(convert(T, acc.logp))
250-
end
251-
function Base.convert(
252-
::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator
253-
) where {T}
254-
return LogJacobianAccumulator(convert(T, acc.logJ))
255-
end
256-
function Base.convert(
257-
::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator
258-
) where {T}
259-
return LogLikelihoodAccumulator(convert(T, acc.logp))
260-
end
261249
function Base.convert(
262250
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
263251
) where {ElType,VnType}
@@ -273,15 +261,6 @@ end
273261
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
274262
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
275263
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
276-
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
277-
return LogPriorAccumulator(convert(T, acc.logp))
278-
end
279-
function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T}
280-
return LogJacobianAccumulator(convert(T, acc.logJ))
281-
end
282-
function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
283-
return LogLikelihoodAccumulator(convert(T, acc.logp))
284-
end
285264

286265
function default_accumulators(
287266
::Type{FloatT}=LogProbType, ::Type{IntT}=Int

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,3 +1332,10 @@ function group_varnames_by_symbol(vns::VarNameTuple)
13321332
elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...))
13331333
return NamedTuple{syms}(elements)
13341334
end
1335+
1336+
"""
1337+
basetypeof(x)
1338+
1339+
Return `typeof(x)` stripped of its type parameters.
1340+
"""
1341+
basetypeof(x::T) where {T} = Base.typename(T).wrapper

0 commit comments

Comments
 (0)