Skip to content

Commit cb1c6c6

Browse files
committed
Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce
1 parent c1e90f7 commit cb1c6c6

File tree

13 files changed

+201
-151
lines changed

13 files changed

+201
-151
lines changed

HISTORY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
1010

11-
- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPrior(),))`.
11+
- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`.
1212
- `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself.
1313
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
1414
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.

docs/src/api.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,9 @@ AbstractAccumulator
356356
DynamicPPL provides the following default accumulators.
357357

358358
```@docs
359-
LogPrior
360-
LogLikelihood
361-
NumProduce
359+
LogPriorAccumulator
360+
LogLikelihoodAccumulator
361+
NumProduceAccumulator
362362
```
363363

364364
### Common API

src/DynamicPPL.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ export AbstractVarInfo,
4848
VarInfo,
4949
SimpleVarInfo,
5050
AbstractAccumulator,
51-
LogLikelihood,
52-
LogPrior,
53-
NumProduce,
51+
LogLikelihoodAccumulator,
52+
LogPriorAccumulator,
53+
NumProduceAccumulator,
5454
push!!,
5555
empty!!,
5656
subset,

src/abstract_varinfo.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Set the log of the prior probability of the parameters sampled in `vi` to `logp`
194194
195195
See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref).
196196
"""
197-
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp))
197+
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))
198198

199199
"""
200200
setloglikelihood!!(vi::AbstractVarInfo, logp)
@@ -203,7 +203,7 @@ Set the log of the likelihood probability of the observed data sampled in `vi` t
203203
204204
See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref).
205205
"""
206-
setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp))
206+
setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp))
207207

208208
"""
209209
setlogp!!(vi::AbstractVarInfo, logp::NamedTuple)
@@ -303,7 +303,7 @@ Add `logp` to the value of the log of the prior probability in `vi`.
303303
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
304304
"""
305305
function acclogprior!!(vi::AbstractVarInfo, logp)
306-
return map_accumulator!!(acc -> acc + LogPrior(logp), vi, Val(:LogPrior))
306+
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
307307
end
308308

309309
"""
@@ -314,7 +314,9 @@ Add `logp` to the value of the log of the likelihood in `vi`.
314314
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
315315
"""
316316
function accloglikelihood!!(vi::AbstractVarInfo, logp)
317-
return map_accumulator!!(acc -> acc + LogLikelihood(logp), vi, Val(:LogLikelihood))
317+
return map_accumulator!!(
318+
acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood)
319+
)
318320
end
319321

320322
"""

src/accumulators.jl

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -193,124 +193,146 @@ end
193193
# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS
194194

195195
"""
196-
LogPrior{T} <: AbstractAccumulator
196+
LogPriorAccumulator{T} <: AbstractAccumulator
197197
198198
An accumulator that tracks the cumulative log prior during model execution.
199199
200200
# Fields
201201
$(TYPEDFIELDS)
202202
"""
203-
struct LogPrior{T} <: AbstractAccumulator
203+
struct LogPriorAccumulator{T} <: AbstractAccumulator
204204
"the scalar log prior value"
205205
logp::T
206206
end
207207

208208
"""
209-
LogPrior{T}()
209+
LogPriorAccumulator{T}()
210210
211-
Create a new `LogPrior` accumulator with the log prior initialized to zero.
211+
Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
212212
"""
213-
LogPrior{T}() where {T} = LogPrior(zero(T))
214-
LogPrior() = LogPrior{LogProbType}()
213+
LogPriorAccumulator{T}() where {T} = LogPriorAccumulator(zero(T))
214+
LogPriorAccumulator() = LogPriorAccumulator{LogProbType}()
215215

216216
"""
217-
LogLikelihood{T} <: AbstractAccumulator
217+
LogLikelihoodAccumulator{T} <: AbstractAccumulator
218218
219219
An accumulator that tracks the cumulative log likelihood during model execution.
220220
221221
# Fields
222222
$(TYPEDFIELDS)
223223
"""
224-
struct LogLikelihood{T} <: AbstractAccumulator
224+
struct LogLikelihoodAccumulator{T} <: AbstractAccumulator
225225
"the scalar log likelihood value"
226226
logp::T
227227
end
228228

229229
"""
230-
LogLikelihood{T}()
230+
LogLikelihoodAccumulator{T}()
231231
232-
Create a new `LogLikelihood` accumulator with the log likelihood initialized to zero.
232+
Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
233233
"""
234-
LogLikelihood{T}() where {T} = LogLikelihood(zero(T))
235-
LogLikelihood() = LogLikelihood{LogProbType}()
234+
LogLikelihoodAccumulator{T}() where {T} = LogLikelihoodAccumulator(zero(T))
235+
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
236236

237237
"""
238-
NumProduce{T} <: AbstractAccumulator
238+
NumProduceAccumulator{T} <: AbstractAccumulator
239239
240240
An accumulator that tracks the number of observations during model execution.
241241
242242
# Fields
243243
$(TYPEDFIELDS)
244244
"""
245-
struct NumProduce{T<:Integer} <: AbstractAccumulator
245+
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
246246
"the number of observations"
247247
num::T
248248
end
249249

250250
"""
251-
NumProduce{T<:Integer}()
251+
NumProduceAccumulator{T<:Integer}()
252252
253-
Create a new `NumProduce` accumulator with the number of observations initialized to zero.
253+
Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
254254
"""
255-
NumProduce{T}() where {T} = NumProduce(zero(T))
256-
NumProduce() = NumProduce{Int}()
255+
NumProduceAccumulator{T}() where {T} = NumProduceAccumulator(zero(T))
256+
NumProduceAccumulator() = NumProduceAccumulator{Int}()
257257

258-
Base.show(io::IO, acc::LogPrior) = print(io, "LogPrior($(repr(acc.logp)))")
259-
Base.show(io::IO, acc::LogLikelihood) = print(io, "LogLikelihood($(repr(acc.logp)))")
260-
Base.show(io::IO, acc::NumProduce) = print(io, "NumProduce($(repr(acc.num)))")
258+
function Base.show(io::IO, acc::LogPriorAccumulator)
259+
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
260+
end
261+
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
262+
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
263+
end
264+
function Base.show(io::IO, acc::NumProduceAccumulator)
265+
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
266+
end
261267

262-
accumulator_name(::Type{<:LogPrior}) = :LogPrior
263-
accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood
264-
accumulator_name(::Type{<:NumProduce}) = :NumProduce
268+
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
269+
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
270+
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
265271

266-
split(::LogPrior{T}) where {T} = LogPrior(zero(T))
267-
split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T))
268-
split(acc::NumProduce) = acc
272+
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
273+
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
274+
split(acc::NumProduceAccumulator) = acc
269275

270-
combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp)
271-
combine(acc::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc.logp + acc2.logp)
272-
function combine(acc::NumProduce, acc2::NumProduce)
273-
return NumProduce(max(acc.num, acc2.num))
276+
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
277+
return LogPriorAccumulator(acc.logp + acc2.logp)
278+
end
279+
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
280+
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
281+
end
282+
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
283+
return NumProduceAccumulator(max(acc.num, acc2.num))
274284
end
275285

276-
Base.:+(acc1::LogPrior, acc2::LogPrior) = LogPrior(acc1.logp + acc2.logp)
277-
Base.:+(acc1::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc1.logp + acc2.logp)
278-
increment(acc::NumProduce) = NumProduce(acc.num + oneunit(acc.num))
286+
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
287+
return LogPriorAccumulator(acc1.logp + acc2.logp)
288+
end
289+
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
290+
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
291+
end
292+
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
279293

280-
Base.zero(acc::LogPrior) = LogPrior(zero(acc.logp))
281-
Base.zero(acc::LogLikelihood) = LogLikelihood(zero(acc.logp))
282-
Base.zero(acc::NumProduce) = NumProduce(zero(acc.num))
294+
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
295+
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
296+
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
283297

284-
function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right)
285-
return acc + LogPrior(logpdf(right, val) + logjac)
298+
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
299+
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
286300
end
287-
accumulate_observe!!(acc::LogPrior, right, left, vn) = acc
301+
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
288302

289-
accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc
290-
function accumulate_observe!!(acc::LogLikelihood, right, left, vn)
303+
accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
304+
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
291305
# Note that it's important to use the loglikelihood function here, not logpdf, because
292306
# they handle vectors differently:
293307
# https://github.com/JuliaStats/Distributions.jl/issues/1972
294-
return acc + LogLikelihood(Distributions.loglikelihood(right, left))
308+
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
295309
end
296310

297-
accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc
298-
accumulate_observe!!(acc::NumProduce, right, left, vn) = increment(acc)
311+
accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
312+
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
299313

300-
Base.convert(::Type{LogPrior{T}}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp))
301-
function Base.convert(::Type{LogLikelihood{T}}, acc::LogLikelihood) where {T}
302-
return LogLikelihood(convert(T, acc.logp))
314+
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
315+
return LogPriorAccumulator(convert(T, acc.logp))
303316
end
304-
function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T}
305-
return NumProduce(convert(T, acc.num))
317+
function Base.convert(
318+
::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator
319+
) where {T}
320+
return LogLikelihoodAccumulator(convert(T, acc.logp))
321+
end
322+
function Base.convert(
323+
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
324+
) where {T}
325+
return NumProduceAccumulator(convert(T, acc.num))
306326
end
307327

308328
# TODO(mhauru)
309-
# We ignore the convert_eltype calls for NumProduce, by letting them fallback on
329+
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
310330
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
311-
# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
331+
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
312332
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
313-
convert_eltype(::Type{T}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp))
314-
function convert_eltype(::Type{T}, acc::LogLikelihood) where {T}
315-
return LogLikelihood(convert(T, acc.logp))
333+
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
334+
return LogPriorAccumulator(convert(T, acc.logp))
335+
end
336+
function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
337+
return LogLikelihoodAccumulator(convert(T, acc.logp))
316338
end

src/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ julia> LogDensityProblems.logdensity(f, [0.0])
7979
-2.3378770664093453
8080
8181
julia> # LogDensityFunction respects the accumulators in VarInfo:
82-
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPrior(),)));
82+
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
8383
8484
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
8585
true

src/model.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,8 +1058,11 @@ See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
10581058
"""
10591059
function logprior(model::Model, varinfo::AbstractVarInfo)
10601060
# Remove other accumulators from varinfo, since they are unnecessary.
1061-
logprior =
1062-
hasacc(varinfo, Val(:LogPrior)) ? getacc(varinfo, Val(:LogPrior)) : LogPrior()
1061+
logprior = if hasacc(varinfo, Val(:LogPrior))
1062+
getacc(varinfo, Val(:LogPrior))
1063+
else
1064+
LogPriorAccumulator()
1065+
end
10631066
varinfo = setaccs!!(deepcopy(varinfo), (logprior,))
10641067
return getlogprior(last(evaluate!!(model, varinfo, DefaultContext())))
10651068
end
@@ -1112,7 +1115,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
11121115
loglikelihood = if hasacc(varinfo, Val(:LogLikelihood))
11131116
getacc(varinfo, Val(:LogLikelihood))
11141117
else
1115-
LogLikelihood()
1118+
LogLikelihoodAccumulator()
11161119
end
11171120
varinfo = setaccs!!(deepcopy(varinfo), (loglikelihood,))
11181121
return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext())))

src/pointwise_logdensities.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function accumulate_assume!!(
6969
# T is the element type of the vectors that are the values of `acc.logps`. Usually
7070
# it's LogProbType.
7171
T = eltype(last(fieldtypes(eltype(acc.logps))))
72-
subacc = accumulate_assume!!(LogPrior{T}(), val, logjac, vn, right)
72+
subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right)
7373
push!(acc, vn, subacc.logp)
7474
end
7575
return acc
@@ -87,7 +87,7 @@ function accumulate_observe!!(
8787
# T is the element type of the vectors that are the values of `acc.logps`. Usually
8888
# it's LogProbType.
8989
T = eltype(last(fieldtypes(eltype(acc.logps))))
90-
subacc = accumulate_observe!!(LogLikelihood{T}(), right, left, vn)
90+
subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn)
9191
push!(acc, vn, subacc.logp)
9292
end
9393
return acc

src/simple_varinfo.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ Evaluation in transformed space of course also works:
125125
126126
```jldoctest simplevarinfo-general
127127
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
128-
Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0)))
128+
Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0)))
129129
130130
julia> # (✓) Positive probability mass on negative numbers!
131131
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
132132
-1.3678794411714423
133133
134134
julia> # While if we forget to indicate that it's transformed:
135135
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
136-
SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0)))
136+
SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0)))
137137
138138
julia> # (✓) No probability mass on negative numbers!
139139
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
@@ -204,7 +204,9 @@ function SimpleVarInfo(values, accs)
204204
return SimpleVarInfo(values, accs, NoTransformation())
205205
end
206206
function SimpleVarInfo{T}(values) where {T<:Real}
207-
return SimpleVarInfo(values, AccumulatorTuple(LogLikelihood{T}(), LogPrior{T}()))
207+
return SimpleVarInfo(
208+
values, AccumulatorTuple(LogLikelihoodAccumulator{T}(), LogPriorAccumulator{T}())
209+
)
208210
end
209211
function SimpleVarInfo(values)
210212
return SimpleVarInfo{LogProbType}(values)

src/transforming.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function _transform!!(
5252
model::Model,
5353
)
5454
# To transform using DynamicTransformationContext, we evaluate the model, but we do not
55-
# need to use any accumulators other than LogPrior (which is affected by the Jacobian of
55+
# need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of
5656
# the transformation).
5757
accs = getaccs(vi)
5858
has_logprior = haskey(accs, Val(:LogPrior))

0 commit comments

Comments
 (0)