Skip to content

Commit b1241b1

Browse files
authored
resetaccs!! before evaluating (#1013)
* resetaccs * Add a test, plus a bunch of == methods * Fix tests * Fix test (properly)
1 parent ea6b6de commit b1241b1

17 files changed

+177
-78
lines changed

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ acclogprior!!
380380
getloglikelihood
381381
setloglikelihood!!
382382
accloglikelihood!!
383-
resetlogp!!
384383
```
385384

386385
#### Variables and their realizations

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ export AbstractVarInfo,
7070
acclogjac!!,
7171
acclogprior!!,
7272
accloglikelihood!!,
73-
resetlogp!!,
7473
is_flagged,
7574
set_flag!,
7675
unset_flag!,

src/abstract_varinfo.jl

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ function map_accumulators!!(func::Function, vi::AbstractVarInfo)
330330
return setaccs!!(vi, map(func, getaccs(vi)))
331331
end
332332

333+
"""
334+
resetaccs!!(vi::AbstractVarInfo)
335+
336+
Reset the values of all accumulators, using [`reset`](@ref).
337+
"""
338+
function resetaccs!!(vi::AbstractVarInfo)
339+
return setaccs!!(vi, map(reset, getaccs(vi)))
340+
end
341+
333342
"""
334343
map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname}
335344
@@ -423,24 +432,6 @@ function acclogp!!(vi::AbstractVarInfo, logp::Number)
423432
return accloglikelihood!!(vi, logp)
424433
end
425434

426-
"""
427-
resetlogp!!(vi::AbstractVarInfo)
428-
429-
Reset the values of the log probabilities (prior and likelihood) in `vi` to zero.
430-
"""
431-
function resetlogp!!(vi::AbstractVarInfo)
432-
if hasacc(vi, Val(:LogPrior))
433-
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
434-
end
435-
if hasacc(vi, Val(:LogJacobian))
436-
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
437-
end
438-
if hasacc(vi, Val(:LogLikelihood))
439-
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
440-
end
441-
return vi
442-
end
443-
444435
# Variables and their realizations.
445436
@doc """
446437
keys(vi::AbstractVarInfo)
@@ -491,7 +482,7 @@ function getindex_internal end
491482
@doc """
492483
empty!!(vi::AbstractVarInfo)
493484
494-
Empty `vi` of variables and reset any `logp` accumulators zeros.
485+
Empty `vi` of variables and reset all accumulators.
495486
496487
This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
497488
""" BangBang.empty!!

src/accumulators.jl

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
1313
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
1414
- `accumulate_observe!!(acc::T, dist, val, vn)`
1515
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
16+
- `reset(acc::T)`
1617
- `Base.copy(acc::T)`
1718
1819
In these functions:
@@ -50,9 +51,7 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))
5051
5152
Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`.
5253
53-
`vn` is the name of the variable being observed, `left` is the value of the variable, and
54-
`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case
55-
of literal observations like `0.0 ~ Normal()`.
54+
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.
5655
5756
`accumulate_observe!!` may mutate `acc`, but not any of the other arguments.
5857
@@ -65,26 +64,45 @@ function accumulate_observe!! end
6564
6665
Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.
6766
68-
`vn` is the name of the variable being assumed, `val` is the value of the variable (in the
69-
original, unlinked space), and `right` is the distribution on the RHS of the tilde
70-
statement. `logjac` is the log determinant of the Jacobian of the transformation that was
71-
done to convert the value of `vn` as it was given to `val`: for example, if the sampler is
72-
operating in linked (Euclidean) space, then logjac will be nonzero.
67+
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.
7368
7469
`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.
7570
7671
See also: [`accumulate_observe!!`](@ref)
7772
"""
7873
function accumulate_assume!! end
7974

75+
"""
76+
reset(acc::AbstractAccumulator)
77+
78+
Return a new accumulator like `acc`, but with its contents reset to the state that they
79+
should be at the beginning of model evaluation.
80+
81+
Note that this may in general have very similar behaviour to [`split`](@ref), and may share
82+
the same implementation, but the difference is that `split` may in principle happen at any
83+
stage during model evaluation, whereas `reset` is only called at the beginning of model
84+
evaluation.
85+
"""
86+
function reset end
87+
88+
@doc """
89+
Base.copy(acc::AbstractAccumulator)
90+
91+
Create a new accumulator that is a copy of `acc`, without aliasing (i.e., this should
92+
behave conceptually like a `deepcopy`).
93+
""" Base.copy
94+
8095
"""
8196
split(acc::AbstractAccumulator)
8297
83-
Return a new accumulator like `acc` but empty.
98+
Return a new accumulator like `acc` suitable for use in a forked thread.
99+
100+
The returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is
101+
used in the context of multi-threading where different threads may accumulate independently
102+
and the results are then combined.
84103
85-
The precise meaning of "empty" is that that the returned value should be such that
86-
`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading
87-
where different threads may accumulate independently and the results are then combined.
104+
Note that this may in general have very similar behaviour to [`reset`](@ref), but is
105+
semantically different. See [`reset`](@ref) for more details.
88106
89107
See also: [`combine`](@ref)
90108
"""

src/debug_utils.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,22 @@ end
156156
const _DEBUG_ACC_NAME = :Debug
157157
DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME
158158

159-
function split(acc::DebugAccumulator)
159+
function Base.:(==)(acc1::DebugAccumulator, acc2::DebugAccumulator)
160+
return (
161+
acc1.varnames_seen == acc2.varnames_seen &&
162+
acc1.statements == acc2.statements &&
163+
acc1.error_on_failure == acc2.error_on_failure
164+
)
165+
end
166+
167+
function _zero(acc::DebugAccumulator)
160168
return DebugAccumulator(
161169
OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure
162170
)
163171
end
164-
function combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
172+
DynamicPPL.reset(acc::DebugAccumulator) = _zero(acc)
173+
DynamicPPL.split(acc::DebugAccumulator) = _zero(acc)
174+
function DynamicPPL.combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
165175
return DebugAccumulator(
166176
merge(acc1.varnames_seen, acc2.varnames_seen),
167177
vcat(acc1.statements, acc2.statements),
@@ -416,7 +426,7 @@ function check_model_and_trace(
416426
issuccess = check_model_pre_evaluation(model)
417427

418428
# Force single-threaded execution.
419-
DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
429+
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
420430

421431
# Perform checks after evaluating the model.
422432
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))

src/default_accumulators.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ end
4747

4848
Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h)
4949

50-
split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T))
51-
50+
_zero(::Tacc) where {Tlogp,Tacc<:LogProbAccumulator{Tlogp}} = Tacc(zero(Tlogp))
51+
reset(acc::LogProbAccumulator) = _zero(acc)
52+
split(acc::LogProbAccumulator) = _zero(acc)
5253
function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator)
5354
if basetypeof(acc) !== basetypeof(acc2)
5455
msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))"
@@ -59,8 +60,6 @@ end
5960

6061
acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val)
6162

62-
Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc)))
63-
6463
function Base.convert(
6564
::Type{AccType}, acc::LogProbAccumulator
6665
) where {T,AccType<:LogProbAccumulator{T}}

src/extract_priors.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ end
1010

1111
accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator
1212

13-
split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
13+
function Base.:(==)(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
14+
return acc1.priors == acc2.priors
15+
end
16+
17+
_zero(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
18+
reset(acc::PriorDistributionAccumulator) = _zero(acc)
19+
split(acc::PriorDistributionAccumulator) = _zero(acc)
1420
function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
1521
return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors))
1622
end

src/model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
884884
See also: [`evaluate_threadsafe!!`](@ref)
885885
"""
886886
function evaluate_threadunsafe!!(model, varinfo)
887-
return _evaluate!!(model, resetlogp!!(varinfo))
887+
return _evaluate!!(model, resetaccs!!(varinfo))
888888
end
889889

890890
"""
@@ -899,7 +899,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
899899
See also: [`evaluate_threadunsafe!!`](@ref)
900900
"""
901901
function evaluate_threadsafe!!(model, varinfo)
902-
wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
902+
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
903903
result, wrapper_new = _evaluate!!(model, wrapper)
904904
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
905905
# will return the underlying VI, which is a bit counterintuitive (because

src/pointwise_logdensities.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
3232
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
3333
end
3434

35+
function Base.:(==)(
36+
acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2}
37+
) where {wlp1,wlp2}
38+
return (wlp1 == wlp2 && acc1.logps == acc2.logps)
39+
end
40+
3541
function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
3642
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
3743
end
@@ -56,10 +62,11 @@ function accumulator_name(
5662
return Symbol("PointwiseLogProbAccumulator{$whichlogprob}")
5763
end
5864

59-
function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
65+
function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
6066
return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps))
6167
end
62-
68+
reset(acc::PointwiseLogProbAccumulator) = _zero(acc)
69+
split(acc::PointwiseLogProbAccumulator) = _zero(acc)
6370
function combine(
6471
acc::PointwiseLogProbAccumulator{whichlogprob},
6572
acc2::PointwiseLogProbAccumulator{whichlogprob},
@@ -223,23 +230,37 @@ function pointwise_logdensities(
223230
# Get the data by executing the model once
224231
vi = VarInfo(model)
225232

233+
# This accumulator tracks the pointwise log-probabilities in a single iteration.
226234
AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType}
227235
vi = setaccs!!(vi, (AccType(),))
228236

229237
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
238+
239+
# Maintain a separate accumulator that isn't tied to a VarInfo but rather
240+
# tracks _all_ iterations.
241+
all_logps = AccType()
230242
for (sample_idx, chain_idx) in iters
231243
# Update the values
232244
setval!(vi, chain, sample_idx, chain_idx)
233245

234246
# Execute model
235247
vi = last(evaluate!!(model, vi))
248+
249+
# Get the log-probabilities
250+
this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps
251+
252+
# Merge into main acc
253+
for (varname, this_lp) in this_iter_logps
254+
# Because `this_lp` is obtained from one model execution, it should only
255+
# contain one variable, hence `only()`.
256+
push!(all_logps, varname, only(this_lp))
257+
end
236258
end
237259

238-
logps = getacc(vi, Val(accumulator_name(AccType))).logps
239260
niters = size(chain, 1)
240261
nchains = size(chain, 3)
241262
logdensities = OrderedDict(
242-
varname => reshape(vals, niters, nchains) for (varname, vals) in logps
263+
varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps
243264
)
244265
return logdensities
245266
end

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
287287
end
288288

289289
function BangBang.empty!!(vi::SimpleVarInfo)
290-
return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values))
290+
return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values))
291291
end
292292
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)
293293

0 commit comments

Comments
 (0)