Skip to content
4 changes: 3 additions & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ makedocs(;
sitename="DynamicPPL",
# The API index.html page is fairly large, and violates the default HTML page size
# threshold of 200KiB, so we double that.
format=Documenter.HTML(; size_threshold=2^10 * 400),
format=Documenter.HTML(;
size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
),
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
pages=[
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators.

```@docs
LogPriorAccumulator
LogJacobianAccumulator
LogLikelihoodAccumulator
VariableOrderAccumulator
```
Expand All @@ -380,7 +381,12 @@ getlogp
setlogp!!
acclogp!!
getlogjoint
getlogjoint_internal
getlogjac
setlogjac!!
acclogjac!!
getlogprior
getlogprior_internal
setlogprior!!
acclogprior!!
getloglikelihood
Expand Down
6 changes: 6 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export AbstractVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
LogJacobianAccumulator,
VariableOrderAccumulator,
push!!,
empty!!,
Expand All @@ -58,10 +59,15 @@ export AbstractVarInfo,
getlogjoint,
getlogprior,
getloglikelihood,
getlogjac,
getlogjoint_internal,
getlogprior_internal,
setlogp!!,
setlogprior!!,
setlogjac!!,
setloglikelihood!!,
acclogp!!,
acclogjac!!,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
Expand Down
102 changes: 90 additions & 12 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
"""
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)

"""
getlogjoint_internal(vi::AbstractVarInfo)

Return the log of the joint probability of the observed data and parameters as
they are stored internally in `vi`, including the log-Jacobian for any linked
parameters.

In general, we have that:

```julia
getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi)
```
"""
getlogjoint_internal(vi::AbstractVarInfo) =
getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi)

"""
getlogp(vi::AbstractVarInfo)

Return a NamedTuple of the log prior and log likelihood probabilities.
Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities.

The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
error will be thrown.
The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them
are not present in `vi` an error will be thrown.
"""
function getlogp(vi::AbstractVarInfo)
return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
return (;
logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi)
)
end

"""
Expand Down Expand Up @@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@
"""
getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp

"""
getlogprior_internal(vi::AbstractVarInfo)

Return the log of the prior probability of the parameters as stored internally
in `vi`. This includes the log-Jacobian for any linked parameters.

In general, we have that:

```julia
getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi)
```
"""
getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi)

"""
getlogjac(vi::AbstractVarInfo)

Return the accumulated log-Jacobian term for any linked parameters in `vi`. The
Jacobian here is taken with respect to the forward (link) transform.

See also: [`setlogjac!!`](@ref).
"""
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ

"""
getloglikelihood(vi::AbstractVarInfo)

Expand Down Expand Up @@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
"""
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))

"""
setlogjac!!(vi::AbstractVarInfo, logJ)

Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
Jacobian here is taken with respect to the forward (link) transform.

See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref).
"""
setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ))

"""
setloglikelihood!!(vi::AbstractVarInfo, logp)

Expand All @@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`.
See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
"""
function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
error("logp must have the fields logprior and loglikelihood and no other fields.")
if Set(names) != Set([:logprior, :logjac, :loglikelihood])
error(
"The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.",
)
end
vi = setlogprior!!(vi, logp.logprior)
vi = setlogjac!!(vi, logp.logjac)
vi = setloglikelihood!!(vi, logp.loglikelihood)
return vi
end

function setlogp!!(vi::AbstractVarInfo, logp::Number)
return error("""
`setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
`setloglikelihood!!` and/or `setlogprior!!` instead.
`setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead.
""")
end

Expand Down Expand Up @@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp)
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
end

"""
acclogjac!!(vi::AbstractVarInfo, logJ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit unsure about the name logJ, which isn't snake_case. Do you have a reason to prefer it over logjac?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It matches the maths notation, and is consistent with Bijectors.jl:

pysm@ati:~/ppl/bi (main) $ rg logJ | wc -l
      43

I don't know to what extent snake_case matters for things that are mathematical variables.

If all the float-accumulators are unified (and presumably the field will be called logp or similar), will this still be a problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the unification I'm introducing a function called logp, so the field name would remain. Happy to be consistent with Bijectors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do realise that in the rest of DynamicPPL we use logjac though, so for library-internal-consistency's sake we should change it. I'll do a big sed later.


Add `logJ` to the value of the log Jacobian in `vi`.

See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
"""
function acclogjac!!(vi::AbstractVarInfo, logJ)
return map_accumulator!!(
acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian)
)
end

"""
accloglikelihood!!(vi::AbstractVarInfo, logp)

Expand Down Expand Up @@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo)
if hasacc(vi, Val(:LogPrior))
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
end
if hasacc(vi, Val(:LogJacobian))
vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
end
if hasacc(vi, Val(:LogLikelihood))
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
end
Expand Down Expand Up @@ -836,8 +907,11 @@ function link!!(
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogprior(vi) - logjac
vi_new = setlogprior!!(unflatten(vi, y), lp_new)
# Set parameters
vi_new = unflatten(vi, y)
# Update logjac. We can overwrite any old value since there is only
# a single logjac term to worry about.
vi_new = setlogjac!!(vi_new, logjac)
return settrans!!(vi_new, t)
end

Expand All @@ -846,10 +920,14 @@ function invlink!!(
)
b = t.bijector
y = vi[:]
x, logjac = with_logabsdet_jacobian(b, y)
x = b(y)

lp_new = getlogprior(vi) + logjac
vi_new = setlogprior!!(unflatten(vi, x), lp_new)
# Set parameters
vi_new = unflatten(vi, x)
# Reset logjac to 0.
if hasacc(vi_new, Val(:LogJacobian))
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a chance that some mix-up of using a different invlink transform than what was originally used for linking would cause the logjac to actually be non-zero? Or would that always imply that quite a serious error has been made and we have no need to have well-defined behaviour?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, true, I guess somebody could do something horrific and use different StaticTransformations, in which case we should sum the logjac terms and add / subtract them as necessary.

end
return settrans!!(vi_new, NoTransformation())
end

Expand Down
15 changes: 13 additions & 2 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,21 @@ seen so far.

An accumulator type `T <: AbstractAccumulator` must implement the following methods:
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, right, left, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
- `accumulate_observe!!(acc::T, dist, val, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
- `Base.copy(acc::T)`

In these functions:
- `val` is the new value of the random variable sampled from a new distribution (always
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the distribution new?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of a typo 😅

in the original unlinked space), or the value on the left-hand side of an observe
statement.
- `dist` is the distribution on the RHS of the tilde statement.
- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the
tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`.
- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the
variable is stored as a linked value in the VarInfo. If the variable is stored in its
original, unlinked form, then `logjac` is zero.

To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
- `combine(acc::T, acc2::T)`
Expand Down
6 changes: 3 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ end
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
x, inv_logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
return x, vi
end

Expand Down Expand Up @@ -166,6 +166,6 @@ function assume(

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
vi = accumulate_assume!!(vi, r, logjac, vn, dist)
return r, vi
end
Loading
Loading