Commit 37e7b34

Remove VariableOrderAcc and merge/subset on accs
1 parent f0ac109 commit 37e7b34

12 files changed, with 27 additions and 594 deletions

HISTORY.md

Lines changed: 11 additions & 1 deletion
@@ -72,6 +72,17 @@ The other case where one might use `PriorContext` was to use `@addlogprob!` to a
7272
Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
7373
Now, you can write `@addlogprob! (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood.
7474

75+
### Removal of `order` and `num_produce`
76+
77+
The `VarInfo` type used to carry information about the order in which variables had been evaluated, along with a counter called `num_produce` that recorded where in this evaluation the `VarInfo` currently was. This was used by the particle samplers in Turing.jl. The particle sampler code has since been simplified and no longer needs this functionality, so it has been removed from DynamicPPL. The following exported functions are now gone:
78+
79+
- `get_num_produce`
80+
- `set_num_produce!!`
81+
- `reset_num_produce!!`
82+
- `increment_num_produce!!`
83+
- `set_retained_vns_del!`
84+
- `setorder!!`
85+
7586
**Internals**
7687

7788
### Accumulators
@@ -83,7 +94,6 @@ This release overhauls how VarInfo objects track variables such as the log joint
8394
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
8495
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
8596
- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
86-
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
8797
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
8898
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
8999
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
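
As a rough illustration of the `NamedTuple` shape described in the last two entries, the sketch below uses a plain `NamedTuple` as a stand-in for what `getlogp` returns. No DynamicPPL functions are called; `toy_logjoint` is a hypothetical helper, and it relies only on the standard identity that the log joint is the log prior plus the log likelihood.

```julia
# Stand-in for the value returned by `getlogp`; the numbers are made up.
logp = (; logprior=-1.5, loglikelihood=-3.0)

# Hypothetical helper (not a DynamicPPL function):
# log joint = log prior + log likelihood.
toy_logjoint(lp::NamedTuple) = lp.logprior + lp.loglikelihood

joint = toy_logjoint(logp)
```

Code that previously treated the result of `getlogp` as a single number would need to pick one field or combine both, as above.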

docs/src/api.md

Lines changed: 0 additions & 11 deletions
@@ -334,17 +334,6 @@ unset_flag!
334334
is_flagged
335335
```
336336

337-
The following functions were used for sequential Monte Carlo methods.
338-
339-
```@docs
340-
get_num_produce
341-
set_num_produce!!
342-
increment_num_produce!!
343-
reset_num_produce!!
344-
setorder!!
345-
set_retained_vns_del!
346-
```
347-
348337
```@docs
349338
Base.empty!
350339
```

src/DynamicPPL.jl

Lines changed: 0 additions & 7 deletions
@@ -51,7 +51,6 @@ export AbstractVarInfo,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
5353
LogJacobianAccumulator,
54-
VariableOrderAccumulator,
5554
push!!,
5655
empty!!,
5756
subset,
@@ -72,15 +71,9 @@ export AbstractVarInfo,
7271
acclogprior!!,
7372
accloglikelihood!!,
7473
resetlogp!!,
75-
get_num_produce,
76-
set_num_produce!!,
77-
reset_num_produce!!,
78-
increment_num_produce!!,
79-
set_retained_vns_del!,
8074
is_flagged,
8175
set_flag!,
8276
unset_flag!,
83-
setorder!!,
8477
istrans,
8578
link,
8679
link!!,

src/abstract_varinfo.jl

Lines changed: 1 addition & 57 deletions
@@ -441,24 +441,6 @@ function resetlogp!!(vi::AbstractVarInfo)
441441
return vi
442442
end
443443

444-
"""
445-
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
446-
447-
Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe
448-
statements run before sampling `vn`.
449-
"""
450-
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
451-
return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
452-
end
453-
454-
"""
455-
getorder(vi::VarInfo, vn::VarName)
456-
457-
Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
458-
run before sampling `vn`.
459-
"""
460-
getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]
461-
462444
# Variables and their realizations.
463445
@doc """
464446
keys(vi::AbstractVarInfo)
@@ -509,8 +491,7 @@ function getindex_internal end
509491
@doc """
510492
empty!!(vi::AbstractVarInfo)
511493
512-
Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to
513-
zeros.
494+
Empty `vi` of variables and reset any `logp` accumulators to zero.
514495
515496
This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
516497
""" BangBang.empty!!
@@ -1068,43 +1049,6 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
10681049
return x, logpdf(dist, x) + logjac
10691050
end
10701051

1071-
"""
1072-
get_num_produce(vi::AbstractVarInfo)
1073-
1074-
Return the `num_produce` of `vi`.
1075-
"""
1076-
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce
1077-
1078-
"""
1079-
set_num_produce!!(vi::AbstractVarInfo, n::Int)
1080-
1081-
Set the `num_produce` field of `vi` to `n`.
1082-
"""
1083-
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
1084-
if hasacc(vi, Val(:VariableOrder))
1085-
acc = getacc(vi, Val(:VariableOrder))
1086-
acc = VariableOrderAccumulator(n, acc.order)
1087-
else
1088-
acc = VariableOrderAccumulator(n)
1089-
end
1090-
return setacc!!(vi, acc)
1091-
end
1092-
1093-
"""
1094-
increment_num_produce!!(vi::AbstractVarInfo)
1095-
1096-
Add 1 to `num_produce` in `vi`.
1097-
"""
1098-
increment_num_produce!!(vi::AbstractVarInfo) =
1099-
map_accumulator!!(increment, vi, Val(:VariableOrder))
1100-
1101-
"""
1102-
reset_num_produce!!(vi::AbstractVarInfo)
1103-
1104-
Reset the value of `num_produce` in `vi` to 0.
1105-
"""
1106-
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))
1107-
11081052
"""
11091053
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
11101054

src/accumulators.jl

Lines changed: 0 additions & 69 deletions
@@ -30,13 +30,6 @@ To be able to work with multi-threading, it should also implement:
3030
- `split(acc::T)`
3131
- `combine(acc::T, acc2::T)`
3232
33-
If two accumulators of the same type should be merged in some non-trivial way, other than
34-
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
35-
36-
If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
37-
do something other than copy the original accumulator, then
38-
`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
39-
4033
See the documentation for each of these functions for more details.
4134
"""
4235
abstract type AbstractAccumulator end
@@ -120,24 +113,6 @@ used by various AD backends, should implement a method for this function.
120113
"""
121114
convert_eltype(::Type, acc::AbstractAccumulator) = acc
122115

123-
"""
124-
subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
125-
126-
Return a new accumulator that only contains the information for the `VarName`s in `vns`.
127-
128-
By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
129-
"""
130-
subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)
131-
132-
"""
133-
merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
134-
135-
Merge two accumulators of the same type. Returns a new accumulator of the same type.
136-
137-
By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
138-
"""
139-
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)
140-
141116
"""
142117
AccumulatorTuple{N,T<:NamedTuple}
143118
@@ -183,50 +158,6 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
183158
return AccumulatorTuple(convert(T, accs.nt))
184159
end
185160

186-
"""
187-
subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
188-
189-
Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
190-
"""
191-
function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
192-
return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
193-
end
194-
195-
"""
196-
_joint_keys(nt1::NamedTuple, nt2::NamedTuple)
197-
198-
A helper function that returns three tuples of keys given two `NamedTuple`s:
199-
The keys only in `nt1`, only in `nt2`, and in both, and in that order.
200-
201-
Implemented as a generated function to enable constant propagation of the result in `merge`.
202-
"""
203-
@generated function _joint_keys(
204-
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
205-
) where {names1,names2}
206-
only_in_nt1 = tuple(setdiff(names1, names2)...)
207-
only_in_nt2 = tuple(setdiff(names2, names1)...)
208-
in_both = tuple(intersect(names1, names2)...)
209-
return :($only_in_nt1, $only_in_nt2, $in_both)
210-
end
211-
212-
"""
213-
merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
214-
215-
Merge two `AccumulatorTuple`s.
216-
217-
For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
218-
accumulators themselves. Other accumulators are copied.
219-
"""
220-
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
221-
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
222-
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
223-
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
224-
accs_in_both = (
225-
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
226-
)
227-
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
228-
end
229-
230161
"""
231162
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
232163

src/default_accumulators.jl

Lines changed: 0 additions & 107 deletions
@@ -166,119 +166,12 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
166166
return acclogp(acc, Distributions.loglikelihood(right, left))
167167
end
168168

169-
"""
170-
VariableOrderAccumulator{T} <: AbstractAccumulator
171-
172-
An accumulator that tracks the order of variables in a `VarInfo`.
173-
174-
This doesn't track the full ordering, but rather how many observations have taken place
175-
before the assume statement for each variable. This is needed for particle methods, where
176-
the model is segmented into parts by each observation, and we need to know which part each
177-
assume statement is in.
178-
179-
# Fields
180-
$(TYPEDFIELDS)
181-
"""
182-
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
183-
"the number of observations"
184-
num_produce::Eltype
185-
"mapping of variable names to their order in the model"
186-
order::Dict{VNType,Eltype}
187-
end
188-
189-
"""
190-
VariableOrderAccumulator{T<:Integer}(n=zero(T))
191-
192-
Create a new `VariableOrderAccumulator` with the number of observations set to `n`.
193-
"""
194-
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
195-
VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
196-
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
197-
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
198-
199-
function Base.copy(acc::VariableOrderAccumulator)
200-
return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
201-
end
202-
203-
function Base.show(io::IO, acc::VariableOrderAccumulator)
204-
return print(
205-
io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))"
206-
)
207-
end
208-
209-
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
210-
return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
211-
end
212-
213-
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
214-
return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
215-
end
216-
217-
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
218-
return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
219-
end
220-
221-
accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder
222-
223-
split(acc::VariableOrderAccumulator) = copy(acc)
224-
225-
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
226-
# Note that assumptions are not allowed in parallelised blocks, and thus the
227-
# dictionaries should be identical.
228-
return VariableOrderAccumulator(
229-
max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order)
230-
)
231-
end
232-
233-
function increment(acc::VariableOrderAccumulator)
234-
return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
235-
end
236-
237-
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
238-
acc.order[vn] = acc.num_produce
239-
return acc
240-
end
241-
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)
242-
243-
function Base.convert(
244-
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
245-
) where {ElType,VnType}
246-
order = Dict{VnType,ElType}()
247-
for (k, v) in acc.order
248-
order[convert(VnType, k)] = convert(ElType, v)
249-
end
250-
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
251-
end
252-
253-
# TODO(mhauru)
254-
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on
255-
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
256-
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
257-
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
258-
259169
function default_accumulators(
260170
::Type{FloatT}=LogProbType, ::Type{IntT}=Int
261171
) where {FloatT,IntT}
262172
return AccumulatorTuple(
263173
LogPriorAccumulator{FloatT}(),
264174
LogJacobianAccumulator{FloatT}(),
265175
LogLikelihoodAccumulator{FloatT}(),
266-
VariableOrderAccumulator{IntT}(),
267176
)
268177
end
269-
270-
function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName})
271-
order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order)
272-
return VariableOrderAccumulator(acc.num_produce, order)
273-
end
274-
275-
"""
276-
merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
277-
278-
Merge two `VariableOrderAccumulator` instances.
279-
280-
The `num_produce` field of the return value is the `num_produce` of `acc2`.
281-
"""
282-
function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
283-
return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order))
284-
end
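
To make concrete what the removed accumulator kept track of, here is a minimal sketch that mirrors the logic of the deleted `accumulate_assume!!` and `accumulate_observe!!` methods using toy names. `ToyOrderTracker`, `on_assume`, `on_observe`, and the `Symbol` keys are assumptions made for illustration; the real type was keyed by `VarName` and lived inside a `VarInfo`.

```julia
# Toy stand-in for the removed accumulator; not the DynamicPPL type.
struct ToyOrderTracker
    num_produce::Int            # number of observations seen so far
    order::Dict{Symbol,Int}     # variable name => observations seen before it
end
ToyOrderTracker() = ToyOrderTracker(0, Dict{Symbol,Int}())

# An assume statement records how many observations preceded it.
function on_assume(t::ToyOrderTracker, name::Symbol)
    t.order[name] = t.num_produce
    return t
end

# An observe statement advances the counter. The struct is immutable,
# so a new value is returned and the caller must use it.
on_observe(t::ToyOrderTracker) = ToyOrderTracker(t.num_produce + 1, t.order)

# Hypothetical model trace: assume, observe, assume, observe.
function run_toy_model()
    t = ToyOrderTracker()
    t = on_assume(t, :mu)
    t = on_observe(t)
    t = on_assume(t, :sigma)
    t = on_observe(t)
    return t
end

tracker = run_toy_model()
```

In this sketch `:mu` is recorded with order 0 and `:sigma` with order 1, which is the kind of segmentation by observation that the particle samplers used to rely on.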

src/simple_varinfo.jl

Lines changed: 4 additions & 4 deletions
@@ -122,15 +122,15 @@ Evaluation in transformed space of course also works:
122122
123123
```jldoctest simplevarinfo-general
124124
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
125-
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))
125+
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
126126
127127
julia> # (✓) Positive probability mass on negative numbers!
128128
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
129129
-1.3678794411714423
130130
131131
julia> # While if we forget to indicate that it's transformed:
132132
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
133-
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))
133+
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
134134
135135
julia> # (✓) No probability mass on negative numbers!
136136
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
@@ -418,7 +418,7 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
418418
# `subset`
419419
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
420420
return SimpleVarInfo(
421-
_subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation
421+
_subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation
422422
)
423423
end
424424

@@ -456,7 +456,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns)
456456
# `merge`
457457
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
458458
values = merge(varinfo_left.values, varinfo_right.values)
459-
accs = merge(getaccs(varinfo_left), getaccs(varinfo_right))
459+
accs = map(copy, getaccs(varinfo_right))
460460
transformation = merge_transformations(
461461
varinfo_left.transformation, varinfo_right.transformation
462462
)
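
The effect of the two changed lines can be sketched without DynamicPPL, assuming a plain `NamedTuple` stands in for the accumulator collection and `ToyLogAcc` is a hypothetical accumulator type. After this change, `subset` copies every accumulator unchanged and `merge` keeps copies of the right-hand accumulators only.

```julia
# Hypothetical accumulator type used only for this sketch.
struct ToyLogAcc
    logp::Float64
end
Base.copy(acc::ToyLogAcc) = ToyLogAcc(acc.logp)

# Plain NamedTuples standing in for the accumulators of two varinfos.
left = (LogPrior=ToyLogAcc(-1.0), LogLikelihood=ToyLogAcc(-2.0))
right = (LogPrior=ToyLogAcc(-3.0), LogLikelihood=ToyLogAcc(-4.0))

# `subset` analogue: every accumulator is copied as-is,
# regardless of which variables are kept.
subset_accs = map(copy, left)

# `merge` analogue: the left-hand accumulators are dropped and
# copies of the right-hand ones are returned.
merged_accs = map(copy, right)
```

Because nothing from the left-hand side survives the merge, any accumulator that needs to combine information from both sides would have to handle that itself.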
