Commit 1ed8cc8

mhauru and penelopeysm authored
Remove VariableOrderAccumulator and subset and merge on accs (#1005)
* Remove VariableOrderAcc and merge/subset on accs
* Fix an import
* Fix docs
* Expand on upstream pMCMC changes
* Remove set_retained_vns_del!(vi::ThreadSafeVarInfo)

---------

Co-authored-by: Penelope Yong <[email protected]>
1 parent f0ac109 commit 1ed8cc8

12 files changed (+45, -600 lines)

HISTORY.md

Lines changed: 29 additions & 1 deletion
@@ -72,6 +72,35 @@ The other case where one might use `PriorContext` was to use `@addlogprob!` to a
 Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
 Now, you can write `@addlogprob (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood.

+### Removal of `order` and `num_produce`
+
+The `VarInfo` type used to carry with it:
+
+- `num_produce`, an integer which recorded the number of observe tilde-statements that had been evaluated so far; and
+- `order`, an integer per `VarName` which recorded the value of `num_produce` at the time that the variable was seen.
+
+These fields were used in particle samplers in Turing.jl.
+In DynamicPPL 0.37, these fields and the associated functions have been removed:
+
+- `get_num_produce`
+- `set_num_produce!!`
+- `reset_num_produce!!`
+- `increment_num_produce!!`
+- `set_retained_vns_del!`
+- `setorder!!`
+
+Because this is one of the more arcane features of DynamicPPL, some extra explanation is warranted.
+
+`num_produce` and `order`, along with the `del` flag in `VarInfo`, were used to control whether new values for variables were sampled during model execution.
+For example, the particle Gibbs method has a _reference particle_, for which variables are never resampled.
+However, if the reference particle is _forked_ (i.e., if the reference particle is selected by a resampling step multiple times and thereby copied), then the variables that have not yet been evaluated must be sampled anew to ensure that the new particle is independent of the reference particle.
+
+Previously, this was accomplished by setting the `del` flag in the `VarInfo` object for all variables with `order` greater than or equal to `num_produce`.
+Note that setting the `del` flag does not itself trigger a new value to be sampled; rather, it indicates that a new value should be sampled _if the variable is encountered again_.
+[This Turing.jl PR](https://github.com/TuringLang/Turing.jl/pull/2629) changes the implementation to set the `del` flag for _all_ variables in the `VarInfo`.
+Since the `del` flag only makes a difference when encountering a variable, this approach is entirely equivalent as long as the same variable is not seen multiple times in the model.
+The interested reader is referred to that PR for more details.
+
 **Internals**

 ### Accumulators
@@ -83,7 +112,6 @@ This release overhauls how VarInfo objects track variables such as the log joint
 - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
 - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
 - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
-- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
 - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
 - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
 - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.

docs/src/api.md

Lines changed: 0 additions & 12 deletions
@@ -334,17 +334,6 @@ unset_flag!
 is_flagged
 ```

-The following functions were used for sequential Monte Carlo methods.
-
-```@docs
-get_num_produce
-set_num_produce!!
-increment_num_produce!!
-reset_num_produce!!
-setorder!!
-set_retained_vns_del!
-```
-
 ```@docs
 Base.empty!
 ```
@@ -369,7 +358,6 @@ DynamicPPL provides the following default accumulators.
 LogPriorAccumulator
 LogJacobianAccumulator
 LogLikelihoodAccumulator
-VariableOrderAccumulator
 ```

 ### Common API

src/DynamicPPL.jl

Lines changed: 0 additions & 7 deletions
@@ -51,7 +51,6 @@ export AbstractVarInfo,
     LogLikelihoodAccumulator,
     LogPriorAccumulator,
     LogJacobianAccumulator,
-    VariableOrderAccumulator,
     push!!,
     empty!!,
     subset,
@@ -72,15 +71,9 @@ export AbstractVarInfo,
     acclogprior!!,
     accloglikelihood!!,
     resetlogp!!,
-    get_num_produce,
-    set_num_produce!!,
-    reset_num_produce!!,
-    increment_num_produce!!,
-    set_retained_vns_del!,
     is_flagged,
     set_flag!,
     unset_flag!,
-    setorder!!,
     istrans,
     link,
     link!!,

src/abstract_varinfo.jl

Lines changed: 1 addition & 57 deletions
@@ -441,24 +441,6 @@ function resetlogp!!(vi::AbstractVarInfo)
     return vi
 end

-"""
-    setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
-
-Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe
-statements run before sampling `vn`.
-"""
-function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
-    return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
-end
-
-"""
-    getorder(vi::VarInfo, vn::VarName)
-
-Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
-run before sampling `vn`.
-"""
-getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]
-
 # Variables and their realizations.
 @doc """
     keys(vi::AbstractVarInfo)
@@ -509,8 +491,7 @@ function getindex_internal end
 @doc """
     empty!!(vi::AbstractVarInfo)

-Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to
-zeros.
+Empty `vi` of variables and reset any `logp` accumulators to zero.

 This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
 """ BangBang.empty!!
@@ -1068,43 +1049,6 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
     return x, logpdf(dist, x) + logjac
 end

-"""
-    get_num_produce(vi::AbstractVarInfo)
-
-Return the `num_produce` of `vi`.
-"""
-get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce
-
-"""
-    set_num_produce!!(vi::AbstractVarInfo, n::Int)
-
-Set the `num_produce` field of `vi` to `n`.
-"""
-function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
-    if hasacc(vi, Val(:VariableOrder))
-        acc = getacc(vi, Val(:VariableOrder))
-        acc = VariableOrderAccumulator(n, acc.order)
-    else
-        acc = VariableOrderAccumulator(n)
-    end
-    return setacc!!(vi, acc)
-end
-
-"""
-    increment_num_produce!!(vi::AbstractVarInfo)
-
-Add 1 to `num_produce` in `vi`.
-"""
-increment_num_produce!!(vi::AbstractVarInfo) =
-    map_accumulator!!(increment, vi, Val(:VariableOrder))
-
-"""
-    reset_num_produce!!(vi::AbstractVarInfo)
-
-Reset the value of `num_produce` in `vi` to 0.
-"""
-reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))
-
 """
     from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

src/accumulators.jl

Lines changed: 0 additions & 69 deletions
@@ -30,13 +30,6 @@ To be able to work with multi-threading, it should also implement:
 - `split(acc::T)`
 - `combine(acc::T, acc2::T)`

-If two accumulators of the same type should be merged in some non-trivial way, other than
-always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
-
-If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
-do something other than copy the original accumulator, then
-`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
-
 See the documentation for each of these functions for more details.
 """
 abstract type AbstractAccumulator end
@@ -120,24 +113,6 @@ used by various AD backends, should implement a method for this function.
 """
 convert_eltype(::Type, acc::AbstractAccumulator) = acc

-"""
-    subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
-
-Return a new accumulator that only contains the information for the `VarName`s in `vns`.
-
-By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
-"""
-subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)
-
-"""
-    merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
-
-Merge two accumulators of the same type. Returns a new accumulator of the same type.
-
-By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
-"""
-Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)
-
 """
     AccumulatorTuple{N,T<:NamedTuple}

@@ -183,50 +158,6 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
     return AccumulatorTuple(convert(T, accs.nt))
 end

-"""
-    subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
-
-Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
-"""
-function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
-    return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
-end
-
-"""
-    _joint_keys(nt1::NamedTuple, nt2::NamedTuple)
-
-A helper function that returns three tuples of keys given two `NamedTuple`s:
-The keys only in `nt1`, only in `nt2`, and in both, and in that order.
-
-Implemented as a generated function to enable constant propagation of the result in `merge`.
-"""
-@generated function _joint_keys(
-    nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
-) where {names1,names2}
-    only_in_nt1 = tuple(setdiff(names1, names2)...)
-    only_in_nt2 = tuple(setdiff(names2, names1)...)
-    in_both = tuple(intersect(names1, names2)...)
-    return :($only_in_nt1, $only_in_nt2, $in_both)
-end
-
-"""
-    merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
-
-Merge two `AccumulatorTuple`s.
-
-For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
-accumulators themselves. Other accumulators are copied.
-"""
-function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
-    keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
-    accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
-    accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
-    accs_in_both = (
-        merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
-    )
-    return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
-end
-
 """
     setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)

src/default_accumulators.jl

Lines changed: 0 additions & 107 deletions
@@ -166,119 +166,12 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
     return acclogp(acc, Distributions.loglikelihood(right, left))
 end

-"""
-    VariableOrderAccumulator{T} <: AbstractAccumulator
-
-An accumulator that tracks the order of variables in a `VarInfo`.
-
-This doesn't track the full ordering, but rather how many observations have taken place
-before the assume statement for each variable. This is needed for particle methods, where
-the model is segmented into parts by each observation, and we need to know which part each
-assume statement is in.
-
-# Fields
-$(TYPEDFIELDS)
-"""
-struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
-    "the number of observations"
-    num_produce::Eltype
-    "mapping of variable names to their order in the model"
-    order::Dict{VNType,Eltype}
-end
-
-"""
-    VariableOrderAccumulator{T<:Integer}(n=zero(T))
-
-Create a new `VariableOrderAccumulator` with the number of observations set to `n`.
-"""
-VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
-    VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
-VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
-VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
-
-function Base.copy(acc::VariableOrderAccumulator)
-    return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
-end
-
-function Base.show(io::IO, acc::VariableOrderAccumulator)
-    return print(
-        io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))"
-    )
-end
-
-function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
-    return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
-end
-
-function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
-    return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
-end
-
-function Base.hash(acc::VariableOrderAccumulator, h::UInt)
-    return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
-end
-
-accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder
-
-split(acc::VariableOrderAccumulator) = copy(acc)
-
-function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
-    # Note that assumptions are not allowed in parallelised blocks, and thus the
-    # dictionaries should be identical.
-    return VariableOrderAccumulator(
-        max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order)
-    )
-end
-
-function increment(acc::VariableOrderAccumulator)
-    return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
-end
-
-function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
-    acc.order[vn] = acc.num_produce
-    return acc
-end
-accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)
-
-function Base.convert(
-    ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
-) where {ElType,VnType}
-    order = Dict{VnType,ElType}()
-    for (k, v) in acc.order
-        order[convert(VnType, k)] = convert(ElType, v)
-    end
-    return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
-end
-
-# TODO(mhauru)
-# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on
-# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
-# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
-# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
-
 function default_accumulators(
     ::Type{FloatT}=LogProbType, ::Type{IntT}=Int
 ) where {FloatT,IntT}
     return AccumulatorTuple(
         LogPriorAccumulator{FloatT}(),
         LogJacobianAccumulator{FloatT}(),
         LogLikelihoodAccumulator{FloatT}(),
-        VariableOrderAccumulator{IntT}(),
     )
 end
-
-function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName})
-    order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order)
-    return VariableOrderAccumulator(acc.num_produce, order)
-end
-
-"""
-    merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
-
-Merge two `VariableOrderAccumulator` instances.
-
-The `num_produce` field of the return value is the `num_produce` of `acc2`.
-"""
-function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
-    return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order))
-end