Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ acclogprior!!
getloglikelihood
setloglikelihood!!
accloglikelihood!!
resetlogp!!
```

#### Variables and their realizations
Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ export AbstractVarInfo,
acclogjac!!,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
is_flagged,
set_flag!,
unset_flag!,
Expand Down
29 changes: 10 additions & 19 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ function map_accumulators!!(func::Function, vi::AbstractVarInfo)
return setaccs!!(vi, map(func, getaccs(vi)))
end

"""
resetaccs!!(vi::AbstractVarInfo)

Reset the values of all accumulators, using [`reset`](@ref).
"""
function resetaccs!!(vi::AbstractVarInfo)
return setaccs!!(vi, map(reset, getaccs(vi)))
end

"""
map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname}

Expand Down Expand Up @@ -423,24 +432,6 @@ function acclogp!!(vi::AbstractVarInfo, logp::Number)
return accloglikelihood!!(vi, logp)
end

"""
resetlogp!!(vi::AbstractVarInfo)

Reset the values of the log probabilities (prior and likelihood) in `vi` to zero.
"""
function resetlogp!!(vi::AbstractVarInfo)
if hasacc(vi, Val(:LogPrior))
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
end
if hasacc(vi, Val(:LogJacobian))
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
end
if hasacc(vi, Val(:LogLikelihood))
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
end
return vi
end

# Variables and their realizations.
@doc """
keys(vi::AbstractVarInfo)
Expand Down Expand Up @@ -491,7 +482,7 @@ function getindex_internal end
@doc """
empty!!(vi::AbstractVarInfo)

Empty `vi` of variables and reset any `logp` accumulators zeros.
Empty `vi` of variables and reset all accumulators.

This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
""" BangBang.empty!!
Expand Down
42 changes: 30 additions & 12 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, dist, val, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
- `reset(acc::T)`
- `Base.copy(acc::T)`

In these functions:
Expand Down Expand Up @@ -50,9 +51,7 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))

Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`.

`vn` is the name of the variable being observed, `left` is the value of the variable, and
`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case
of literal observations like `0.0 ~ Normal()`.
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.

`accumulate_observe!!` may mutate `acc`, but not any of the other arguments.

Expand All @@ -65,26 +64,45 @@ function accumulate_observe!! end

Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.

`vn` is the name of the variable being assumed, `val` is the value of the variable (in the
original, unlinked space), and `right` is the distribution on the RHS of the tilde
statement. `logjac` is the log determinant of the Jacobian of the transformation that was
done to convert the value of `vn` as it was given to `val`: for example, if the sampler is
operating in linked (Euclidean) space, then logjac will be nonzero.
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.

`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.

See also: [`accumulate_observe!!`](@ref)
"""
function accumulate_assume!! end

"""
reset(acc::AbstractAccumulator)

Return a new accumulator like `acc`, but with its contents reset to the state that they
should be at the beginning of model evaluation.

Note that this may in general have very similar behaviour to [`split`](@ref), and may share
the same implementation, but the difference is that `split` may in principle happen at any
stage during model evaluation, whereas `reset` is only called at the beginning of model
evaluation.
"""
function reset end

@doc """
Base.copy(acc::AbstractAccumulator)

Create a new accumulator that is a copy of `acc`, without aliasing (i.e., this should
behave conceptually like a `deepcopy`).
""" Base.copy

"""
split(acc::AbstractAccumulator)

Return a new accumulator like `acc` but empty.
Return a new accumulator like `acc` suitable for use in a forked thread.

The returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is
used in the context of multi-threading where different threads may accumulate independently
and the results are then combined.

The precise meaning of "empty" is that that the returned value should be such that
`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading
where different threads may accumulate independently and the results are then combined.
Note that this may in general have very similar behaviour to [`reset`](@ref), but is
semantically different. See [`reset`](@ref) for more details.

See also: [`combine`](@ref)
"""
Expand Down
14 changes: 12 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,22 @@ end
const _DEBUG_ACC_NAME = :Debug
DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME

function split(acc::DebugAccumulator)
function Base.:(==)(acc1::DebugAccumulator, acc2::DebugAccumulator)
return (
acc1.varnames_seen == acc2.varnames_seen &&
acc1.statements == acc2.statements &&
acc1.error_on_failure == acc2.error_on_failure
)
end

function _zero(acc::DebugAccumulator)
return DebugAccumulator(
OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure
)
end
function combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
DynamicPPL.reset(acc::DebugAccumulator) = _zero(acc)
DynamicPPL.split(acc::DebugAccumulator) = _zero(acc)
function DynamicPPL.combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
return DebugAccumulator(
merge(acc1.varnames_seen, acc2.varnames_seen),
vcat(acc1.statements, acc2.statements),
Expand Down
7 changes: 3 additions & 4 deletions src/default_accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ end

Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h)

split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T))

_zero(::Tacc) where {Tlogp,Tacc<:LogProbAccumulator{Tlogp}} = Tacc(zero(Tlogp))
reset(acc::LogProbAccumulator) = _zero(acc)
split(acc::LogProbAccumulator) = _zero(acc)
function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator)
if basetypeof(acc) !== basetypeof(acc2)
msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))"
Expand All @@ -59,8 +60,6 @@ end

acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val)

Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc)))

function Base.convert(
::Type{AccType}, acc::LogProbAccumulator
) where {T,AccType<:LogProbAccumulator{T}}
Expand Down
8 changes: 7 additions & 1 deletion src/extract_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ end

accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator

split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
function Base.:(==)(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
return acc1.priors == acc2.priors
end

_zero(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
reset(acc::PriorDistributionAccumulator) = _zero(acc)
split(acc::PriorDistributionAccumulator) = _zero(acc)
function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors))
end
Expand Down
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
See also: [`evaluate_threadsafe!!`](@ref)
"""
function evaluate_threadunsafe!!(model, varinfo)
return _evaluate!!(model, resetlogp!!(varinfo))
return _evaluate!!(model, resetaccs!!(varinfo))
end

"""
Expand All @@ -899,7 +899,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
See also: [`evaluate_threadunsafe!!`](@ref)
"""
function evaluate_threadsafe!!(model, varinfo)
wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
result, wrapper_new = _evaluate!!(model, wrapper)
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
# will return the underlying VI, which is a bit counterintuitive (because
Expand Down
11 changes: 9 additions & 2 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
end

function Base.:(==)(
acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2}
) where {wlp1,wlp2}
return (wlp1 == wlp2 && acc1.logps == acc2.logps)
end

Comment on lines +35 to +40
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the test that I wrote (checking that all the accs are == reset(acc) originally broke because most of them don't have an appropriate == method defined, so I had to introduce these (which mimic other changes we've made previously). I wonder if there's a better way of doing this than all the boilerplate that I introduced here though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's a better way of doing this than all the boilerplate that I introduced here though.

I suspect not, not without some sketchy metaprogramming or doing a loop over fieldnames or something else that feels dangerous.

function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
end
Expand All @@ -56,10 +62,11 @@ function accumulator_name(
return Symbol("PointwiseLogProbAccumulator{$whichlogprob}")
end

function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps))
end

reset(acc::PointwiseLogProbAccumulator) = _zero(acc)
split(acc::PointwiseLogProbAccumulator) = _zero(acc)
function combine(
acc::PointwiseLogProbAccumulator{whichlogprob},
acc2::PointwiseLogProbAccumulator{whichlogprob},
Expand Down
2 changes: 1 addition & 1 deletion src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
end

function BangBang.empty!!(vi::SimpleVarInfo)
return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values))
return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values))
end
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)

Expand Down
22 changes: 4 additions & 18 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,27 +171,13 @@ end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
function BangBang.empty!!(vi::ThreadSafeVarInfo)
return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
end

function resetlogp!!(vi::ThreadSafeVarInfo)
vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo)
function resetaccs!!(vi::ThreadSafeVarInfo)
vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo)
for i in eachindex(vi.accs_by_thread)
if hasacc(vi, Val(:LogPrior))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogPrior)
)
end
if hasacc(vi, Val(:LogJacobian))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogJacobian)
)
end
if hasacc(vi, Val(:LogLikelihood))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogLikelihood)
)
end
vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i])
end
return vi
end
Expand Down
8 changes: 7 additions & 1 deletion src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@ function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
end

function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
return (acc1.include_colon_eq == acc2.include_colon_eq && acc1.values == acc2.values)
end

function Base.copy(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq)
end

accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

function split(acc::ValuesAsInModelAccumulator)
function _zero(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
Comment on lines +33 to 35
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I preferred using _zero over Base.zero because the latter would effectively make this method public, and it doesn't seem like there's any reason to, since the public interface is split and reset.

reset(acc::ValuesAsInModelAccumulator) = _zero(acc)
split(acc::ValuesAsInModelAccumulator) = _zero(acc)
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
if acc1.include_colon_eq != acc2.include_colon_eq
msg = "Cannot combine accumulators with different include_colon_eq values."
Expand Down
2 changes: 1 addition & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ end

function BangBang.empty!!(vi::VarInfo)
_empty!(vi.metadata)
vi = resetlogp!!(vi)
vi = resetaccs!!(vi)
return vi
end

Expand Down
1 change: 1 addition & 0 deletions test/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using DynamicPPL:
convert_eltype,
getacc,
map_accumulator,
reset,
setacc!!,
split

Expand Down
2 changes: 1 addition & 1 deletion test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
end

# Reset the logp accumulators.
svi_eval = DynamicPPL.resetlogp!!(svi_eval)
svi_eval = DynamicPPL.resetaccs!!(svi_eval)

# Compute `logjoint` using the varinfo.
logπ = logjoint(model, svi_eval)
Expand Down
2 changes: 1 addition & 1 deletion test/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@test getlogjoint(vi) == lp
@test getlogjoint(threadsafe_vi) == lp + 42

threadsafe_vi = resetlogp!!(threadsafe_vi)
threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi)
@test iszero(getlogjoint(threadsafe_vi))
expected_accs = DynamicPPL.AccumulatorTuple(
(DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...
Expand Down
Loading
Loading