Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ acclogprior!!
getloglikelihood
setloglikelihood!!
accloglikelihood!!
resetlogp!!
```

#### Variables and their realizations
Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ export AbstractVarInfo,
acclogjac!!,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
is_flagged,
set_flag!,
unset_flag!,
Expand Down
29 changes: 10 additions & 19 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ function map_accumulators!!(func::Function, vi::AbstractVarInfo)
return setaccs!!(vi, map(func, getaccs(vi)))
end

"""
resetaccs!!(vi::AbstractVarInfo)

Reset the values of all accumulators, using [`reset`](@ref).
"""
function resetaccs!!(vi::AbstractVarInfo)
return setaccs!!(vi, map(reset, getaccs(vi)))
end

"""
map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname}

Expand Down Expand Up @@ -423,24 +432,6 @@ function acclogp!!(vi::AbstractVarInfo, logp::Number)
return accloglikelihood!!(vi, logp)
end

"""
resetlogp!!(vi::AbstractVarInfo)

Reset the values of the log probabilities (prior and likelihood) in `vi` to zero.
"""
function resetlogp!!(vi::AbstractVarInfo)
if hasacc(vi, Val(:LogPrior))
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
end
if hasacc(vi, Val(:LogJacobian))
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
end
if hasacc(vi, Val(:LogLikelihood))
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
end
return vi
end

# Variables and their realizations.
@doc """
keys(vi::AbstractVarInfo)
Expand Down Expand Up @@ -491,7 +482,7 @@ function getindex_internal end
@doc """
empty!!(vi::AbstractVarInfo)

Empty `vi` of variables and reset any `logp` accumulators zeros.
Empty `vi` of variables and reset all accumulators.

This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
""" BangBang.empty!!
Expand Down
42 changes: 30 additions & 12 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, dist, val, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
- `reset(acc::T)`
- `Base.copy(acc::T)`

In these functions:
Expand Down Expand Up @@ -50,9 +51,7 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))

Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`.

`vn` is the name of the variable being observed, `left` is the value of the variable, and
`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case
of literal observations like `0.0 ~ Normal()`.
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.

`accumulate_observe!!` may mutate `acc`, but not any of the other arguments.

Expand All @@ -65,26 +64,45 @@ function accumulate_observe!! end

Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.

`vn` is the name of the variable being assumed, `val` is the value of the variable (in the
original, unlinked space), and `right` is the distribution on the RHS of the tilde
statement. `logjac` is the log determinant of the Jacobian of the transformation that was
done to convert the value of `vn` as it was given to `val`: for example, if the sampler is
operating in linked (Euclidean) space, then logjac will be nonzero.
See [`AbstractAccumulator`](@ref) for the meaning of the arguments.

`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.

See also: [`accumulate_observe!!`](@ref)
"""
function accumulate_assume!! end

"""
reset(acc::AbstractAccumulator)

Return a new accumulator like `acc`, but with its contents reset to the state that they
should be at the beginning of model evaluation.

Note that this may in general have very similar behaviour to [`split`](@ref), and may share
the same implementation, but the difference is that `split` may in principle happen at any
stage during model evaluation, whereas `reset` is only called at the beginning of model
evaluation.
"""
function reset end

@doc """
Base.copy(acc::AbstractAccumulator)

Create a new accumulator that is a copy of `acc`, without aliasing (i.e., this should
behave conceptually like a `deepcopy`).
""" Base.copy

"""
split(acc::AbstractAccumulator)

Return a new accumulator like `acc` but empty.
Return a new accumulator like `acc` suitable for use in a forked thread.

The returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is
used in the context of multi-threading where different threads may accumulate independently
and the results are then combined.

The precise meaning of "empty" is that that the returned value should be such that
`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading
where different threads may accumulate independently and the results are then combined.
Note that this may in general have very similar behaviour to [`reset`](@ref), but is
semantically different. See [`reset`](@ref) for more details.

See also: [`combine`](@ref)
"""
Expand Down
16 changes: 13 additions & 3 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,22 @@ end
const _DEBUG_ACC_NAME = :Debug
DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME

function split(acc::DebugAccumulator)
function Base.:(==)(acc1::DebugAccumulator, acc2::DebugAccumulator)
return (
acc1.varnames_seen == acc2.varnames_seen &&
acc1.statements == acc2.statements &&
acc1.error_on_failure == acc2.error_on_failure
)
end

function _zero(acc::DebugAccumulator)
return DebugAccumulator(
OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure
)
end
function combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
DynamicPPL.reset(acc::DebugAccumulator) = _zero(acc)
DynamicPPL.split(acc::DebugAccumulator) = _zero(acc)
function DynamicPPL.combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
return DebugAccumulator(
merge(acc1.varnames_seen, acc2.varnames_seen),
vcat(acc1.statements, acc2.statements),
Expand Down Expand Up @@ -416,7 +426,7 @@ function check_model_and_trace(
issuccess = check_model_pre_evaluation(model)

# Force single-threaded execution.
DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot.


# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
Expand Down
7 changes: 3 additions & 4 deletions src/default_accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ end

Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h)

split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T))

_zero(::Tacc) where {Tlogp,Tacc<:LogProbAccumulator{Tlogp}} = Tacc(zero(Tlogp))
reset(acc::LogProbAccumulator) = _zero(acc)
split(acc::LogProbAccumulator) = _zero(acc)
function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator)
if basetypeof(acc) !== basetypeof(acc2)
msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))"
Expand All @@ -59,8 +60,6 @@ end

acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val)

Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc)))

function Base.convert(
::Type{AccType}, acc::LogProbAccumulator
) where {T,AccType<:LogProbAccumulator{T}}
Expand Down
8 changes: 7 additions & 1 deletion src/extract_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ end

accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator

split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
function Base.:(==)(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
return acc1.priors == acc2.priors
end

_zero(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
reset(acc::PriorDistributionAccumulator) = _zero(acc)
split(acc::PriorDistributionAccumulator) = _zero(acc)
function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors))
end
Expand Down
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
See also: [`evaluate_threadsafe!!`](@ref)
"""
function evaluate_threadunsafe!!(model, varinfo)
return _evaluate!!(model, resetlogp!!(varinfo))
return _evaluate!!(model, resetaccs!!(varinfo))
end

"""
Expand All @@ -899,7 +899,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
See also: [`evaluate_threadunsafe!!`](@ref)
"""
function evaluate_threadsafe!!(model, varinfo)
wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
result, wrapper_new = _evaluate!!(model, wrapper)
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
# will return the underlying VI, which is a bit counterintuitive (because
Expand Down
29 changes: 25 additions & 4 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
end

function Base.:(==)(
acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2}
) where {wlp1,wlp2}
return (wlp1 == wlp2 && acc1.logps == acc2.logps)
end

Comment on lines +35 to +40
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the test that I wrote (checking that all the accs are == reset(acc) originally broke because most of them don't have an appropriate == method defined, so I had to introduce these (which mimic other changes we've made previously). I wonder if there's a better way of doing this than all the boilerplate that I introduced here though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's a better way of doing this than all the boilerplate that I introduced here though.

I suspect not, not without some sketchy metaprogramming or doing a loop over fieldnames or something else that feels dangerous.

function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
end
Expand All @@ -56,10 +62,11 @@ function accumulator_name(
return Symbol("PointwiseLogProbAccumulator{$whichlogprob}")
end

function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps))
end

reset(acc::PointwiseLogProbAccumulator) = _zero(acc)
split(acc::PointwiseLogProbAccumulator) = _zero(acc)
function combine(
acc::PointwiseLogProbAccumulator{whichlogprob},
acc2::PointwiseLogProbAccumulator{whichlogprob},
Expand Down Expand Up @@ -223,23 +230,37 @@ function pointwise_logdensities(
# Get the data by executing the model once
vi = VarInfo(model)

# This accumulator tracks the pointwise log-probabilities in a single iteration.
AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType}
vi = setaccs!!(vi, (AccType(),))

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))

# Maintain a separate accumulator that isn't tied to a VarInfo but rather
# tracks _all_ iterations.
all_logps = AccType()
for (sample_idx, chain_idx) in iters
# Update the values
setval!(vi, chain, sample_idx, chain_idx)

# Execute model
vi = last(evaluate!!(model, vi))

# Get the log-probabilities
this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps

# Merge into main acc
for (varname, this_lp) in this_iter_logps
# Because `this_lp` is obtained from one model execution, it should only
# contain one variable, hence `only()`.
push!(all_logps, varname, only(this_lp))
end
end

logps = getacc(vi, Val(accumulator_name(AccType))).logps
niters = size(chain, 1)
nchains = size(chain, 3)
logdensities = OrderedDict(
varname => reshape(vals, niters, nchains) for (varname, vals) in logps
varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps
)
return logdensities
end
Expand Down
2 changes: 1 addition & 1 deletion src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
end

function BangBang.empty!!(vi::SimpleVarInfo)
return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values))
return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values))
end
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)

Expand Down
22 changes: 4 additions & 18 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,27 +171,13 @@ end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
function BangBang.empty!!(vi::ThreadSafeVarInfo)
return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
end

function resetlogp!!(vi::ThreadSafeVarInfo)
vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo)
function resetaccs!!(vi::ThreadSafeVarInfo)
vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo)
for i in eachindex(vi.accs_by_thread)
if hasacc(vi, Val(:LogPrior))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogPrior)
)
end
if hasacc(vi, Val(:LogJacobian))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogJacobian)
)
end
if hasacc(vi, Val(:LogLikelihood))
vi.accs_by_thread[i] = map_accumulator(
zero, vi.accs_by_thread[i], Val(:LogLikelihood)
)
end
vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i])
end
return vi
end
Expand Down
8 changes: 7 additions & 1 deletion src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@ function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
end

function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
return (acc1.include_colon_eq == acc2.include_colon_eq && acc1.values == acc2.values)
end

function Base.copy(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq)
end

accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

function split(acc::ValuesAsInModelAccumulator)
function _zero(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
Comment on lines +33 to 35
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I preferred using _zero over Base.zero because the latter would effectively make this method public, and it doesn't seem like there's any reason to, since the public interface is split and reset.

reset(acc::ValuesAsInModelAccumulator) = _zero(acc)
split(acc::ValuesAsInModelAccumulator) = _zero(acc)
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
if acc1.include_colon_eq != acc2.include_colon_eq
msg = "Cannot combine accumulators with different include_colon_eq values."
Expand Down
2 changes: 1 addition & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ end

function BangBang.empty!!(vi::VarInfo)
_empty!(vi.metadata)
vi = resetlogp!!(vi)
vi = resetaccs!!(vi)
return vi
end

Expand Down
Loading
Loading