Remove VariableOrderAccumulator and subset and merge on accs #1005

Merged 5 commits on Aug 7, 2025
30 changes: 29 additions & 1 deletion HISTORY.md
@@ -72,6 +72,35 @@ The other case where one might use `PriorContext` was to use `@addlogprob!` to a
Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
Now, you can write `@addlogprob! (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood.
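
For instance, a minimal sketch of the new form (the model below is illustrative and not part of this changelog):

```julia
using DynamicPPL, Distributions

@model function demo(y)
    x ~ Normal()
    # Add a custom penalty to the log-prior and the data term to the log-likelihood.
    @addlogprob! (; logprior=-abs(x), loglikelihood=loglikelihood(Normal(x, 1), y))
end
```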

### Removal of `order` and `num_produce`

The `VarInfo` type used to carry with it:

- `num_produce`, an integer which recorded the number of observe tilde-statements that had been evaluated so far; and
- `order`, an integer per `VarName` which recorded the value of `num_produce` at the time that the variable was seen.

These fields were used in particle samplers in Turing.jl.
In DynamicPPL 0.37, these fields and the associated functions have been removed:

- `get_num_produce`
- `set_num_produce!!`
- `reset_num_produce!!`
- `increment_num_produce!!`
- `set_retained_vns_del!`
- `setorder!!`

Because this is one of the more arcane features of DynamicPPL, some extra explanation is warranted.

`num_produce` and `order`, along with the `del` flag in `VarInfo`, were used to control whether new values for variables were sampled during model execution.
For example, the particle Gibbs method has a _reference particle_, for which variables are never resampled.
However, if the reference particle is _forked_ (i.e., if the reference particle is selected by a resampling step multiple times and thereby copied), then the variables that have not yet been evaluated must be sampled anew to ensure that the new particle is independent of the reference particle.

Previously, this was accomplished by setting the `del` flag in the `VarInfo` object for all variables with `order` greater than or equal to `num_produce`.
Note that setting the `del` flag does not itself trigger a new value to be sampled; rather, it indicates that a new value should be sampled _if the variable is encountered again_.
[This Turing.jl PR](https://github.com/TuringLang/Turing.jl/pull/2629) changes the implementation to set the `del` flag for _all_ variables in the `VarInfo`.
Since the `del` flag only makes a difference when encountering a variable, this approach is entirely equivalent as long as the same variable is not seen multiple times in the model.
The interested reader is referred to that PR for more details.
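
As a rough illustration, here is a hedged sketch of the two strategies in terms of the now-removed API (the helper names are hypothetical; the actual pre-0.37 logic lived in functions like `set_retained_vns_del!`):

```julia
# Hypothetical helpers illustrating the two strategies; not actual DynamicPPL code.

# Pre-0.37: mark only variables not yet reached by the current execution,
# using the removed `order`/`num_produce` bookkeeping.
function mark_for_resampling_old(vi)
    n = get_num_produce(vi)
    for vn in keys(vi)
        if getorder(vi, vn) >= n
            set_flag!(vi, vn, "del")  # resample this variable if it is seen again
        end
    end
    return vi
end

# Post-0.37 (TuringLang/Turing.jl#2629): mark *all* variables. Equivalent as long
# as no variable is encountered more than once during model execution.
function mark_for_resampling_new(vi)
    for vn in keys(vi)
        set_flag!(vi, vn, "del")
    end
    return vi
end
```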

**Internals**

### Accumulators
@@ -83,7 +112,6 @@ This release overhauls how VarInfo objects track variables such as the log joint
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
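
A minimal sketch of the new log-probability accessors described in the last two items, assuming `vi` is an `AbstractVarInfo` obtained by evaluating a model under DynamicPPL 0.37:

```julia
lp = getlogp(vi)      # NamedTuple, e.g. (logprior = -1.2, loglikelihood = -3.4)
lj = getlogjoint(vi)  # scalar log joint; what `getlogp` used to return
vi = setlogp!!(vi, (; logprior=0.0, loglikelihood=0.0))
vi = acclogp!!(vi, (; logprior=-1.2, loglikelihood=-3.4))
```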
12 changes: 0 additions & 12 deletions docs/src/api.md
@@ -334,17 +334,6 @@ unset_flag!
is_flagged
```

The following functions were used for sequential Monte Carlo methods.

```@docs
get_num_produce
set_num_produce!!
increment_num_produce!!
reset_num_produce!!
setorder!!
set_retained_vns_del!
```

```@docs
Base.empty!
```
@@ -369,7 +358,6 @@ DynamicPPL provides the following default accumulators.
LogPriorAccumulator
LogJacobianAccumulator
LogLikelihoodAccumulator
VariableOrderAccumulator
```

### Common API
7 changes: 0 additions & 7 deletions src/DynamicPPL.jl
@@ -51,7 +51,6 @@ export AbstractVarInfo,
LogLikelihoodAccumulator,
LogPriorAccumulator,
LogJacobianAccumulator,
VariableOrderAccumulator,
push!!,
empty!!,
subset,
@@ -72,15 +71,9 @@ export AbstractVarInfo,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
get_num_produce,
set_num_produce!!,
reset_num_produce!!,
increment_num_produce!!,
set_retained_vns_del!,
is_flagged,
set_flag!,
unset_flag!,
setorder!!,
istrans,
link,
link!!,
58 changes: 1 addition & 57 deletions src/abstract_varinfo.jl
@@ -441,24 +441,6 @@ function resetlogp!!(vi::AbstractVarInfo)
return vi
end

"""
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)

Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
statements run before sampling `vn`.
"""
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
end

"""
getorder(vi::AbstractVarInfo, vn::VarName)

Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
run before sampling `vn`.
"""
getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]

# Variables and their realizations.
@doc """
keys(vi::AbstractVarInfo)
Expand Down Expand Up @@ -509,8 +491,7 @@ function getindex_internal end
@doc """
empty!!(vi::AbstractVarInfo)

Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to
zeros.
Empty `vi` of variables and reset any `logp` accumulators to zero.

This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
""" BangBang.empty!!
Expand Down Expand Up @@ -1068,43 +1049,6 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
return x, logpdf(dist, x) + logjac
end

"""
get_num_produce(vi::AbstractVarInfo)

Return the `num_produce` of `vi`.
"""
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce

"""
set_num_produce!!(vi::AbstractVarInfo, n::Int)

Set the `num_produce` field of `vi` to `n`.
"""
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
if hasacc(vi, Val(:VariableOrder))
acc = getacc(vi, Val(:VariableOrder))
acc = VariableOrderAccumulator(n, acc.order)
else
acc = VariableOrderAccumulator(n)
end
return setacc!!(vi, acc)
end

"""
increment_num_produce!!(vi::AbstractVarInfo)

Add 1 to `num_produce` in `vi`.
"""
increment_num_produce!!(vi::AbstractVarInfo) =
map_accumulator!!(increment, vi, Val(:VariableOrder))

"""
reset_num_produce!!(vi::AbstractVarInfo)

Reset the value of `num_produce` in `vi` to 0.
"""
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))

"""
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

69 changes: 0 additions & 69 deletions src/accumulators.jl
@@ -30,13 +30,6 @@ To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
- `combine(acc::T, acc2::T)`

If two accumulators of the same type should be merged in some non-trivial way, other than
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.

If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
do something other than copy the original accumulator, then
`subset(acc::T, vns::AbstractVector{<:VarName})` should be defined.

See the documentation for each of these functions for more details.
"""
abstract type AbstractAccumulator end
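
For context, here is a hedged sketch of what a custom accumulator can look like under this interface, based on the method signatures visible elsewhere in this diff (e.g. on `VariableOrderAccumulator` below); the type and its behaviour are illustrative, not an official example:

```julia
# Counts how many tilde-statements are executed during one model evaluation.
struct TildeCountAccumulator <: DynamicPPL.AbstractAccumulator
    n::Int
end
TildeCountAccumulator() = TildeCountAccumulator(0)

DynamicPPL.accumulator_name(::Type{<:TildeCountAccumulator}) = :TildeCount

# Called for every assumption (`x ~ dist`) and every observation, respectively.
function DynamicPPL.accumulate_assume!!(acc::TildeCountAccumulator, val, logjac, vn, right)
    return TildeCountAccumulator(acc.n + 1)
end
function DynamicPPL.accumulate_observe!!(acc::TildeCountAccumulator, right, left, vn)
    return TildeCountAccumulator(acc.n + 1)
end

# For multi-threaded evaluation: `split` hands a fresh accumulator to a threaded
# block and `combine` folds the per-thread results back together.
DynamicPPL.split(::TildeCountAccumulator) = TildeCountAccumulator(0)
function DynamicPPL.combine(acc1::TildeCountAccumulator, acc2::TildeCountAccumulator)
    return TildeCountAccumulator(acc1.n + acc2.n)
end
```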
@@ -120,24 +113,6 @@ used by various AD backends, should implement a method for this function.
"""
convert_eltype(::Type, acc::AbstractAccumulator) = acc

"""
subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})

Return a new accumulator that only contains the information for the `VarName`s in `vns`.

By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
"""
subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc)

"""
merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)

Merge two accumulators of the same type. Returns a new accumulator of the same type.

By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
"""
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)

"""
AccumulatorTuple{N,T<:NamedTuple}

@@ -183,50 +158,6 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
return AccumulatorTuple(convert(T, accs.nt))
end

"""
subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})

Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
"""
function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt))
end

"""
_joint_keys(nt1::NamedTuple, nt2::NamedTuple)

A helper function that returns three tuples of keys given two `NamedTuple`s:
The keys only in `nt1`, only in `nt2`, and in both, and in that order.

Implemented as a generated function to enable constant propagation of the result in `merge`.
"""
@generated function _joint_keys(
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
) where {names1,names2}
only_in_nt1 = tuple(setdiff(names1, names2)...)
only_in_nt2 = tuple(setdiff(names2, names1)...)
in_both = tuple(intersect(names1, names2)...)
return :($only_in_nt1, $only_in_nt2, $in_both)
end

"""
merge(at1::AccumulatorTuple, at2::AccumulatorTuple)

Merge two `AccumulatorTuple`s.

For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
accumulators themselves. Other accumulators are copied.
"""
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
accs_in_both = (
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
)
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
end

"""
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)

107 changes: 0 additions & 107 deletions src/default_accumulators.jl
@@ -166,119 +166,12 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
return acclogp(acc, Distributions.loglikelihood(right, left))
end

"""
VariableOrderAccumulator{T} <: AbstractAccumulator

An accumulator that tracks the order of variables in a `VarInfo`.

This doesn't track the full ordering, but rather how many observations have taken place
before the assume statement for each variable. This is needed for particle methods, where
the model is segmented into parts by each observation, and we need to know which part each
assume statement is in.

# Fields
$(TYPEDFIELDS)
"""
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
"the number of observations"
num_produce::Eltype
"mapping of variable names to their order in the model"
order::Dict{VNType,Eltype}
end

"""
VariableOrderAccumulator{T<:Integer}(n=zero(T))

Create a new `VariableOrderAccumulator` with the number of observations set to `n`.
"""
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()

function Base.copy(acc::VariableOrderAccumulator)
return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
end

function Base.show(io::IO, acc::VariableOrderAccumulator)
return print(
io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))"
)
end

function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
end

function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
end

function Base.hash(acc::VariableOrderAccumulator, h::UInt)
return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
end

accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder

split(acc::VariableOrderAccumulator) = copy(acc)

function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
# Note that assumptions are not allowed in parallelised blocks, and thus the
# dictionaries should be identical.
return VariableOrderAccumulator(
max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order)
)
end

function increment(acc::VariableOrderAccumulator)
return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
end

function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
acc.order[vn] = acc.num_produce
return acc
end
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)

function Base.convert(
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
) where {ElType,VnType}
order = Dict{VnType,ElType}()
for (k, v) in acc.order
order[convert(VnType, k)] = convert(ElType, v)
end
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
end

# TODO(mhauru)
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.

function default_accumulators(
::Type{FloatT}=LogProbType, ::Type{IntT}=Int
) where {FloatT,IntT}
return AccumulatorTuple(
LogPriorAccumulator{FloatT}(),
LogJacobianAccumulator{FloatT}(),
LogLikelihoodAccumulator{FloatT}(),
VariableOrderAccumulator{IntT}(),
)
end

function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName})
order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order)
return VariableOrderAccumulator(acc.num_produce, order)
end

"""
merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)

Merge two `VariableOrderAccumulator` instances.

The `num_produce` field of the return value is the `num_produce` of `acc2`.
"""
function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order))
end