1
1
"""
2
- LogPriorAccumulator{T<:Real} <: AbstractAccumulator
2
+ LogProbAccumulator{T} <: AbstractAccumulator
3
+
4
+ An abstract type for accumulators that hold a single scalar log probability value.
5
+
6
+ Every subtype of `LogProbAccumulator` must implement
7
+ * A method for `logp` that returns the scalar log probability value that defines it.
8
+ * A single-argument constructor that takes a `logp` value.
9
+ * `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any
10
+ other accumulator.
11
+
12
+ `LogProbAccumulator` provides implementations for other common functions, like convenience
13
+ constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`.
14
+
15
+ This type has no great conceptual significance, it just reduces code duplication between
16
+ types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator.
17
+ """
18
+ abstract type LogProbAccumulator{T<: Real } <: AbstractAccumulator end
19
+
20
+ # The first of the below methods sets AccType{T}() = AccType(zero(T)) for any
21
+ # AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T
22
+ # when calling AccType().
23
+ """
24
+ LogProbAccumulator{T}()
25
+
26
+ Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero.
27
+ """
28
+ (:: Type{AccType} )() where {T<: Real ,AccType<: LogProbAccumulator{T} } = AccType (zero (T))
29
+ (:: Type{AccType} )() where {AccType<: LogProbAccumulator } = AccType {LogProbType} ()
30
+
31
+ Base. copy (acc:: LogProbAccumulator ) = acc
32
+
33
+ function Base. show (io:: IO , acc:: LogProbAccumulator )
34
+ return print (io, " $(repr (accumulator_name (acc))) ($(repr (logp (acc))) ))" )
35
+ end
36
+
37
+ # Note that == and isequal are different, and equality under the latter should imply
38
+ # equality of hashes. Both of the below implementations are also different from the default
39
+ # implementation for structs.
40
+ function Base.:(== )(acc1:: LogProbAccumulator , acc2:: LogProbAccumulator )
41
+ return accumulator_name (acc1) === accumulator_name (acc2) && logp (acc1) == logp (acc2)
42
+ end
43
+
44
+ function Base. isequal (acc1:: LogProbAccumulator , acc2:: LogProbAccumulator )
45
+ return basetypeof (acc1) === basetypeof (acc2) && isequal (logp (acc1), logp (acc2))
46
+ end
47
+
48
+ Base. hash (acc:: T , h:: UInt ) where {T<: LogProbAccumulator } = hash ((T, logp (acc)), h)
49
+
50
+ split (:: AccType ) where {T,AccType<: LogProbAccumulator{T} } = AccType (zero (T))
51
+
52
+ function combine (acc:: LogProbAccumulator , acc2:: LogProbAccumulator )
53
+ if basetypeof (acc) != = basetypeof (acc2)
54
+ msg = " Cannot combine accumulators of different types: $(basetypeof (acc)) and $(basetypeof (acc2)) "
55
+ throw (ArgumentError (msg))
56
+ end
57
+ return basetypeof (acc)(logp (acc) + logp (acc2))
58
+ end
59
+
60
+ function Base.:+ (acc1:: LogProbAccumulator , acc2:: LogProbAccumulator )
61
+ if basetypeof (acc1) != = basetypeof (acc2)
62
+ msg = " Cannot add accumulators of different types: $(basetypeof (acc1)) and $(basetypeof (acc2)) "
63
+ throw (ArgumentError (msg))
64
+ end
65
+ return basetypeof (acc1)(logp (acc1) + logp (acc2))
66
+ end
67
+
68
+ Base. zero (acc:: T ) where {T<: LogProbAccumulator } = T (zero (logp (acc)))
69
+
70
+ function Base. convert (
71
+ :: Type{AccType} , acc:: LogProbAccumulator
72
+ ) where {T,AccType<: LogProbAccumulator{T} }
73
+ return AccType (convert (T, logp (acc)))
74
+ end
75
+
76
+ function convert_eltype (:: Type{T} , acc:: LogProbAccumulator ) where {T}
77
+ return basetypeof (acc)(convert (T, logp (acc)))
78
+ end
79
+
80
+ """
81
+ LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T}
3
82
4
83
An accumulator that tracks the cumulative log prior during model execution.
5
84
@@ -10,21 +89,22 @@ linked or not.
10
89
# Fields
11
90
$(TYPEDFIELDS)
12
91
"""
13
- struct LogPriorAccumulator{T<: Real } <: AbstractAccumulator
92
+ struct LogPriorAccumulator{T<: Real } <: LogProbAccumulator{T}
14
93
" the scalar log prior value"
15
94
logp:: T
16
95
end
17
96
18
- """
19
- LogPriorAccumulator{T}()
97
+ logp (acc:: LogPriorAccumulator ) = acc. logp
20
98
21
- Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
22
- """
23
- LogPriorAccumulator {T} () where {T<: Real } = LogPriorAccumulator (zero (T))
24
- LogPriorAccumulator () = LogPriorAccumulator {LogProbType} ()
99
+ accumulator_name (:: Type{<:LogPriorAccumulator} ) = :LogPrior
100
+
101
+ function accumulate_assume!! (acc:: LogPriorAccumulator , val, logjac, vn, right)
102
+ return acc + LogPriorAccumulator (logpdf (right, val))
103
+ end
104
+ accumulate_observe!! (acc:: LogPriorAccumulator , right, left, vn) = acc
25
105
26
106
"""
27
- LogJacobianAccumulator{T<:Real} <: AbstractAccumulator
107
+ LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T}
28
108
29
109
An accumulator that tracks the cumulative log Jacobian (technically,
30
110
log(abs(det(J)))) during model execution. Specifically, J refers to the
@@ -53,39 +133,44 @@ distribution to unconstrained space.
53
133
# Fields
54
134
$(TYPEDFIELDS)
55
135
"""
56
- struct LogJacobianAccumulator{T<: Real } <: AbstractAccumulator
136
+ struct LogJacobianAccumulator{T<: Real } <: LogProbAccumulator{T}
57
137
" the logabsdet of the link transform Jacobian"
58
138
logJ:: T
59
139
end
60
140
61
- """
62
- LogJacobianAccumulator{T}()
141
+ logp (acc:: LogJacobianAccumulator ) = acc. logJ
63
142
64
- Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero.
65
- """
66
- LogJacobianAccumulator {T} () where {T<: Real } = LogJacobianAccumulator (zero (T))
67
- LogJacobianAccumulator () = LogJacobianAccumulator {LogProbType} ()
143
+ accumulator_name (:: Type{<:LogJacobianAccumulator} ) = :LogJacobian
144
+
145
+ function accumulate_assume!! (acc:: LogJacobianAccumulator , val, logjac, vn, right)
146
+ return acc + LogJacobianAccumulator (logjac)
147
+ end
148
+ accumulate_observe!! (acc:: LogJacobianAccumulator , right, left, vn) = acc
68
149
69
150
"""
70
- LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
151
+ LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T}
71
152
72
153
An accumulator that tracks the cumulative log likelihood during model execution.
73
154
74
155
# Fields
75
156
$(TYPEDFIELDS)
76
157
"""
77
- struct LogLikelihoodAccumulator{T<: Real } <: AbstractAccumulator
158
+ struct LogLikelihoodAccumulator{T<: Real } <: LogProbAccumulator{T}
78
159
" the scalar log likelihood value"
79
160
logp:: T
80
161
end
81
162
82
- """
83
- LogLikelihoodAccumulator{T}()
163
+ logp (acc:: LogLikelihoodAccumulator ) = acc. logp
84
164
85
- Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
86
- """
87
- LogLikelihoodAccumulator {T} () where {T<: Real } = LogLikelihoodAccumulator (zero (T))
88
- LogLikelihoodAccumulator () = LogLikelihoodAccumulator {LogProbType} ()
165
+ accumulator_name (:: Type{<:LogLikelihoodAccumulator} ) = :LogLikelihood
166
+
167
+ accumulate_assume!! (acc:: LogLikelihoodAccumulator , val, logjac, vn, right) = acc
168
+ function accumulate_observe!! (acc:: LogLikelihoodAccumulator , right, left, vn)
169
+ # Note that it's important to use the loglikelihood function here, not logpdf, because
170
+ # they handle vectors differently:
171
+ # https://github.com/JuliaStats/Distributions.jl/issues/1972
172
+ return acc + LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
173
+ end
89
174
90
175
"""
91
176
VariableOrderAccumulator{T} <: AbstractAccumulator
@@ -117,85 +202,32 @@ VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
117
202
VariableOrderAccumulator (n) = VariableOrderAccumulator {typeof(n)} (n)
118
203
VariableOrderAccumulator () = VariableOrderAccumulator {Int} ()
119
204
120
- Base. copy (acc:: LogPriorAccumulator ) = acc
121
- Base. copy (acc:: LogJacobianAccumulator ) = acc
122
- Base. copy (acc:: LogLikelihoodAccumulator ) = acc
123
205
function Base. copy (acc:: VariableOrderAccumulator )
124
206
return VariableOrderAccumulator (acc. num_produce, copy (acc. order))
125
207
end
126
208
127
- function Base. show (io:: IO , acc:: LogPriorAccumulator )
128
- return print (io, " LogPriorAccumulator($(repr (acc. logp)) )" )
129
- end
130
- function Base. show (io:: IO , acc:: LogJacobianAccumulator )
131
- return print (io, " LogJacobianAccumulator($(repr (acc. logJ)) )" )
132
- end
133
- function Base. show (io:: IO , acc:: LogLikelihoodAccumulator )
134
- return print (io, " LogLikelihoodAccumulator($(repr (acc. logp)) )" )
135
- end
136
209
function Base. show (io:: IO , acc:: VariableOrderAccumulator )
137
210
return print (
138
211
io, " VariableOrderAccumulator($(repr (acc. num_produce)) , $(repr (acc. order)) )"
139
212
)
140
213
end
141
214
142
- # Note that == and isequal are different, and equality under the latter should imply
143
- # equality of hashes. Both of the below implementations are also different from the default
144
- # implementation for structs.
145
- Base.:(== )(acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator ) = acc1. logp == acc2. logp
146
- function Base.:(== )(acc1:: LogJacobianAccumulator , acc2:: LogJacobianAccumulator )
147
- return acc1. logJ == acc2. logJ
148
- end
149
- function Base.:(== )(acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
150
- return acc1. logp == acc2. logp
151
- end
152
215
function Base.:(== )(acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
153
216
return acc1. num_produce == acc2. num_produce && acc1. order == acc2. order
154
217
end
155
218
156
- function Base. isequal (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
157
- return isequal (acc1. logp, acc2. logp)
158
- end
159
- function Base. isequal (acc1:: LogJacobianAccumulator , acc2:: LogJacobianAccumulator )
160
- return isequal (acc1. logJ, acc2. logJ)
161
- end
162
- function Base. isequal (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
163
- return isequal (acc1. logp, acc2. logp)
164
- end
165
219
function Base. isequal (acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
166
220
return isequal (acc1. num_produce, acc2. num_produce) && isequal (acc1. order, acc2. order)
167
221
end
168
222
169
- Base. hash (acc:: LogPriorAccumulator , h:: UInt ) = hash ((LogPriorAccumulator, acc. logp), h)
170
- function Base. hash (acc:: LogJacobianAccumulator , h:: UInt )
171
- return hash ((LogJacobianAccumulator, acc. logJ), h)
172
- end
173
- function Base. hash (acc:: LogLikelihoodAccumulator , h:: UInt )
174
- return hash ((LogLikelihoodAccumulator, acc. logp), h)
175
- end
176
223
function Base. hash (acc:: VariableOrderAccumulator , h:: UInt )
177
224
return hash ((VariableOrderAccumulator, acc. num_produce, acc. order), h)
178
225
end
179
226
180
- accumulator_name (:: Type{<:LogPriorAccumulator} ) = :LogPrior
181
- accumulator_name (:: Type{<:LogJacobianAccumulator} ) = :LogJacobian
182
- accumulator_name (:: Type{<:LogLikelihoodAccumulator} ) = :LogLikelihood
183
227
accumulator_name (:: Type{<:VariableOrderAccumulator} ) = :VariableOrder
184
228
185
- split (:: LogPriorAccumulator{T} ) where {T} = LogPriorAccumulator (zero (T))
186
- split (:: LogJacobianAccumulator{T} ) where {T} = LogJacobianAccumulator (zero (T))
187
- split (:: LogLikelihoodAccumulator{T} ) where {T} = LogLikelihoodAccumulator (zero (T))
188
229
split (acc:: VariableOrderAccumulator ) = copy (acc)
189
230
190
- function combine (acc:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
191
- return LogPriorAccumulator (acc. logp + acc2. logp)
192
- end
193
- function combine (acc:: LogJacobianAccumulator , acc2:: LogJacobianAccumulator )
194
- return LogJacobianAccumulator (acc. logJ + acc2. logJ)
195
- end
196
- function combine (acc:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
197
- return LogLikelihoodAccumulator (acc. logp + acc2. logp)
198
- end
199
231
function combine (acc:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
200
232
# Note that assumptions are not allowed in parallelised blocks, and thus the
201
233
# dictionaries should be identical.
@@ -204,60 +236,16 @@ function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
204
236
)
205
237
end
206
238
207
- function Base.:+ (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
208
- return LogPriorAccumulator (acc1. logp + acc2. logp)
209
- end
210
- function Base.:+ (acc1:: LogJacobianAccumulator , acc2:: LogJacobianAccumulator )
211
- return LogJacobianAccumulator (acc1. logJ + acc2. logJ)
212
- end
213
- function Base.:+ (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
214
- return LogLikelihoodAccumulator (acc1. logp + acc2. logp)
215
- end
216
239
function increment (acc:: VariableOrderAccumulator )
217
240
return VariableOrderAccumulator (acc. num_produce + oneunit (acc. num_produce), acc. order)
218
241
end
219
242
220
- Base. zero (acc:: LogPriorAccumulator ) = LogPriorAccumulator (zero (acc. logp))
221
- Base. zero (acc:: LogJacobianAccumulator ) = LogJacobianAccumulator (zero (acc. logJ))
222
- Base. zero (acc:: LogLikelihoodAccumulator ) = LogLikelihoodAccumulator (zero (acc. logp))
223
-
224
- function accumulate_assume!! (acc:: LogPriorAccumulator , val, logjac, vn, right)
225
- return acc + LogPriorAccumulator (logpdf (right, val))
226
- end
227
- accumulate_observe!! (acc:: LogPriorAccumulator , right, left, vn) = acc
228
-
229
- function accumulate_assume!! (acc:: LogJacobianAccumulator , val, logjac, vn, right)
230
- return acc + LogJacobianAccumulator (logjac)
231
- end
232
- accumulate_observe!! (acc:: LogJacobianAccumulator , right, left, vn) = acc
233
-
234
- accumulate_assume!! (acc:: LogLikelihoodAccumulator , val, logjac, vn, right) = acc
235
- function accumulate_observe!! (acc:: LogLikelihoodAccumulator , right, left, vn)
236
- # Note that it's important to use the loglikelihood function here, not logpdf, because
237
- # they handle vectors differently:
238
- # https://github.com/JuliaStats/Distributions.jl/issues/1972
239
- return acc + LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
240
- end
241
-
242
243
function accumulate_assume!! (acc:: VariableOrderAccumulator , val, logjac, vn, right)
243
244
acc. order[vn] = acc. num_produce
244
245
return acc
245
246
end
246
247
accumulate_observe!! (acc:: VariableOrderAccumulator , right, left, vn) = increment (acc)
247
248
248
- function Base. convert (:: Type{LogPriorAccumulator{T}} , acc:: LogPriorAccumulator ) where {T}
249
- return LogPriorAccumulator (convert (T, acc. logp))
250
- end
251
- function Base. convert (
252
- :: Type{LogJacobianAccumulator{T}} , acc:: LogJacobianAccumulator
253
- ) where {T}
254
- return LogJacobianAccumulator (convert (T, acc. logJ))
255
- end
256
- function Base. convert (
257
- :: Type{LogLikelihoodAccumulator{T}} , acc:: LogLikelihoodAccumulator
258
- ) where {T}
259
- return LogLikelihoodAccumulator (convert (T, acc. logp))
260
- end
261
249
function Base. convert (
262
250
:: Type{VariableOrderAccumulator{ElType,VnType}} , acc:: VariableOrderAccumulator
263
251
) where {ElType,VnType}
273
261
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
274
262
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
275
263
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
276
- function convert_eltype (:: Type{T} , acc:: LogPriorAccumulator ) where {T}
277
- return LogPriorAccumulator (convert (T, acc. logp))
278
- end
279
- function convert_eltype (:: Type{T} , acc:: LogJacobianAccumulator ) where {T}
280
- return LogJacobianAccumulator (convert (T, acc. logJ))
281
- end
282
- function convert_eltype (:: Type{T} , acc:: LogLikelihoodAccumulator ) where {T}
283
- return LogLikelihoodAccumulator (convert (T, acc. logp))
284
- end
285
264
286
265
function default_accumulators (
287
266
:: Type{FloatT} = LogProbType, :: Type{IntT} = Int
0 commit comments