Skip to content
32 changes: 26 additions & 6 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,40 @@ Their semantics are the same as in Julia's `isapprox`; two values are equal if t
You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`).

### Removal of `PriorContext` and `LikelihoodContext`

A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.
### Evaluating model log-probabilities in more detail

Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field.
`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively.

Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below).
In this version, we have overhauled this quite substantially.
The technical details of exactly _how_ this is done is covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked.

Specifically, you will want to use the following functions to access these log probabilities:

- `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked.
- `getloglikelihood(varinfo)` to get the log likelihood.
- `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked.

If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space.
To this end, you can use:

- `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables.
- `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space.
- `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space.

Since transformations only apply to random variables, the likelihood is unaffected by linking.

### Removal of `PriorContext` and `LikelihoodContext`

Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.

If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood).

If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`.
`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want.
Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.)

The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior.
Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ makedocs(;
sitename="DynamicPPL",
# The API index.html page is fairly large, and violates the default HTML page size
# threshold of 200KiB, so we double that.
format=Documenter.HTML(; size_threshold=2^10 * 400),
format=Documenter.HTML(;
size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
),
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
pages=[
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators.

```@docs
LogPriorAccumulator
LogJacobianAccumulator
LogLikelihoodAccumulator
VariableOrderAccumulator
```
Expand All @@ -380,7 +381,12 @@ getlogp
setlogp!!
acclogp!!
getlogjoint
getlogjoint_internal
getlogjac
setlogjac!!
acclogjac!!
getlogprior
getlogprior_internal
setlogprior!!
acclogprior!!
getloglikelihood
Expand Down
6 changes: 6 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export AbstractVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
LogJacobianAccumulator,
VariableOrderAccumulator,
push!!,
empty!!,
Expand All @@ -58,10 +59,15 @@ export AbstractVarInfo,
getlogjoint,
getlogprior,
getloglikelihood,
getlogjac,
getlogjoint_internal,
getlogprior_internal,
setlogp!!,
setlogprior!!,
setlogjac!!,
setloglikelihood!!,
acclogp!!,
acclogjac!!,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
Expand Down
102 changes: 90 additions & 12 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
"""
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)

"""
getlogjoint_internal(vi::AbstractVarInfo)

Return the log of the joint probability of the observed data and parameters as
they are stored internally in `vi`, including the log-Jacobian for any linked
parameters.

In general, we have that:

```julia
getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi)
```
"""
getlogjoint_internal(vi::AbstractVarInfo) =
getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi)

"""
getlogp(vi::AbstractVarInfo)

Return a NamedTuple of the log prior and log likelihood probabilities.
Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities.

The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
error will be thrown.
The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them
are not present in `vi` an error will be thrown.
"""
function getlogp(vi::AbstractVarInfo)
return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
return (;
logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi)
)
end

"""
Expand Down Expand Up @@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@
"""
getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp

"""
getlogprior_internal(vi::AbstractVarInfo)

Return the log of the prior probability of the parameters as stored internally
in `vi`. This includes the log-Jacobian for any linked parameters.

In general, we have that:

```julia
getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi)
```
"""
getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi)

"""
getlogjac(vi::AbstractVarInfo)

Return the accumulated log-Jacobian term for any linked parameters in `vi`. The
Jacobian here is taken with respect to the forward (link) transform.

See also: [`setlogjac!!`](@ref).
"""
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ

"""
getloglikelihood(vi::AbstractVarInfo)

Expand Down Expand Up @@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
"""
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))

"""
setlogjac!!(vi::AbstractVarInfo, logJ)

Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
Jacobian here is taken with respect to the forward (link) transform.

See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref).
"""
setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ))

"""
setloglikelihood!!(vi::AbstractVarInfo, logp)

Expand All @@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`.
See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
"""
function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
error("logp must have the fields logprior and loglikelihood and no other fields.")
if Set(names) != Set([:logprior, :logjac, :loglikelihood])
error(
"The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.",
)
end
vi = setlogprior!!(vi, logp.logprior)
vi = setlogjac!!(vi, logp.logjac)
vi = setloglikelihood!!(vi, logp.loglikelihood)
return vi
end

function setlogp!!(vi::AbstractVarInfo, logp::Number)
return error("""
`setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
`setloglikelihood!!` and/or `setlogprior!!` instead.
`setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead.
""")
end

Expand Down Expand Up @@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp)
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
end

"""
acclogjac!!(vi::AbstractVarInfo, logJ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit unsure about the name logJ, which isn't snake_case. Do you have a reason to prefer it over logjac?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It matches the maths notation, and is consistent with Bijectors.jl:

pysm@ati:~/ppl/bi (main) $ rg logJ | wc -l
      43

I don't know to what extent snake_case matters for things that are mathematical variables.

If all the float-accumulators are unified (and presumably the field will be called logp or similar), will this still be a problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the unification I'm introducing a function called logp, so the field name would remain. Happy to be consistent with Bijectors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do realise that in the rest of DynamicPPL we use logjac though, so for library-internal-consistency's sake we should change it. I'll do a big sed later.


Add `logJ` to the value of the log Jacobian in `vi`.

See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
"""
function acclogjac!!(vi::AbstractVarInfo, logJ)
return map_accumulator!!(
acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian)
)
end

"""
accloglikelihood!!(vi::AbstractVarInfo, logp)

Expand Down Expand Up @@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo)
if hasacc(vi, Val(:LogPrior))
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
end
if hasacc(vi, Val(:LogJacobian))
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
end
if hasacc(vi, Val(:LogLikelihood))
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
end
Expand Down Expand Up @@ -836,8 +907,11 @@ function link!!(
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogprior(vi) - logjac
vi_new = setlogprior!!(unflatten(vi, y), lp_new)
# Set parameters
vi_new = unflatten(vi, y)
# Update logjac. We can overwrite any old value since there is only
# a single logjac term to worry about.
vi_new = setlogjac!!(vi_new, logjac)
return settrans!!(vi_new, t)
end

Expand All @@ -846,10 +920,14 @@ function invlink!!(
)
b = t.bijector
y = vi[:]
x, logjac = with_logabsdet_jacobian(b, y)
x = b(y)

lp_new = getlogprior(vi) + logjac
vi_new = setlogprior!!(unflatten(vi, x), lp_new)
# Set parameters
vi_new = unflatten(vi, x)
# Reset logjac to 0.
if hasacc(vi_new, Val(:LogJacobian))
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a chance that some mix-up of using a different invlink transform than what was originally used for linking would cause the logjac to actually be non-zero? Or would that always imply that quite a serious error has been made and we have no need to have well-defined behaviour?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, true, I guess somebody could do something horrific and use different StaticTransformations, in which case we should sum the logjac terms and add / subtract them as necessary.

end
return settrans!!(vi_new, NoTransformation())
end

Expand Down
15 changes: 13 additions & 2 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,21 @@ seen so far.

An accumulator type `T <: AbstractAccumulator` must implement the following methods:
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, right, left, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
- `accumulate_observe!!(acc::T, dist, val, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
- `Base.copy(acc::T)`

In these functions:
- `val` is the new value of the random variable sampled from a new distribution (always
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the distribution new?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of a typo 😅

in the original unlinked space), or the value on the left-hand side of an observe
statement.
- `dist` is the distribution on the RHS of the tilde statement.
- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the
tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`.
- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the
variable is stored as a linked value in the VarInfo. If the variable is stored in its
original, unlinked form, then `logjac` is zero.

To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
- `combine(acc::T, acc2::T)`
Expand Down
6 changes: 3 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ end
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
x, inv_logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
return x, vi
end

Expand Down Expand Up @@ -166,6 +166,6 @@ function assume(

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
vi = accumulate_assume!!(vi, r, logjac, vn, dist)
return r, vi
end
Loading
Loading