-
Notifications
You must be signed in to change notification settings - Fork 36
Implement more consistent tracking of logp components via LogJacobianAccumulator
#998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
a29b953
d6c9cfa
e671a56
5a4b01b
974c282
60b6863
53a2f61
a47641c
10de51f
a19bf2f
48a6048
049b3d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). | |
""" | ||
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) | ||
|
||
""" | ||
getlogjoint_internal(vi::AbstractVarInfo) | ||
|
||
Return the log of the joint probability of the observed data and parameters as | ||
they are stored internally in `vi`, including the log-Jacobian for any linked | ||
parameters. | ||
|
||
In general, we have that: | ||
|
||
```julia | ||
getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi) | ||
``` | ||
""" | ||
getlogjoint_internal(vi::AbstractVarInfo) = | ||
getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi) | ||
|
||
""" | ||
getlogp(vi::AbstractVarInfo) | ||
|
||
Return a NamedTuple of the log prior and log likelihood probabilities. | ||
Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities. | ||
|
||
The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an | ||
error will be thrown. | ||
The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them | ||
are not present in `vi` an error will be thrown. | ||
""" | ||
function getlogp(vi::AbstractVarInfo) | ||
return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) | ||
return (; | ||
logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi) | ||
) | ||
end | ||
|
||
""" | ||
|
@@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ | |
""" | ||
getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp | ||
|
||
""" | ||
getlogprior_internal(vi::AbstractVarInfo) | ||
|
||
Return the log of the prior probability of the parameters as stored internally | ||
in `vi`. This includes the log-Jacobian for any linked parameters. | ||
|
||
In general, we have that: | ||
|
||
```julia | ||
getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi) | ||
``` | ||
""" | ||
getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi) | ||
|
||
""" | ||
getlogjac(vi::AbstractVarInfo) | ||
|
||
Return the accumulated log-Jacobian term for any linked parameters in `vi`. The | ||
Jacobian here is taken with respect to the forward (link) transform. | ||
|
||
See also: [`setlogjac!!`](@ref). | ||
""" | ||
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ | ||
|
||
""" | ||
getloglikelihood(vi::AbstractVarInfo) | ||
|
||
|
@@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re | |
""" | ||
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) | ||
|
||
""" | ||
setlogjac!!(vi::AbstractVarInfo, logJ) | ||
|
||
Set the accumulated log-Jacobian term for any linked parameters in `vi`. The | ||
Jacobian here is taken with respect to the forward (link) transform. | ||
|
||
See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref). | ||
""" | ||
setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ)) | ||
|
||
""" | ||
setloglikelihood!!(vi::AbstractVarInfo, logp) | ||
|
||
|
@@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`. | |
See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). | ||
""" | ||
function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} | ||
if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) | ||
error("logp must have the fields logprior and loglikelihood and no other fields.") | ||
if Set(names) != Set([:logprior, :logjac, :loglikelihood]) | ||
error( | ||
"The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.", | ||
) | ||
end | ||
vi = setlogprior!!(vi, logp.logprior) | ||
vi = setlogjac!!(vi, logp.logjac) | ||
vi = setloglikelihood!!(vi, logp.loglikelihood) | ||
return vi | ||
end | ||
|
||
function setlogp!!(vi::AbstractVarInfo, logp::Number) | ||
return error(""" | ||
`setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use | ||
`setloglikelihood!!` and/or `setlogprior!!` instead. | ||
`setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. | ||
""") | ||
end | ||
|
||
|
@@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp) | |
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) | ||
end | ||
|
||
""" | ||
acclogjac!!(vi::AbstractVarInfo, logJ) | ||
|
||
|
||
Add `logJ` to the value of the log Jacobian in `vi`. | ||
|
||
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). | ||
""" | ||
function acclogjac!!(vi::AbstractVarInfo, logJ) | ||
return map_accumulator!!( | ||
acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian) | ||
) | ||
end | ||
|
||
""" | ||
accloglikelihood!!(vi::AbstractVarInfo, logp) | ||
|
||
|
@@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo) | |
if hasacc(vi, Val(:LogPrior)) | ||
vi = map_accumulator!!(zero, vi, Val(:LogPrior)) | ||
end | ||
if hasacc(vi, Val(:LogJacobian)) | ||
vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) | ||
end | ||
if hasacc(vi, Val(:LogLikelihood)) | ||
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) | ||
end | ||
|
@@ -836,8 +907,11 @@ function link!!( | |
x = vi[:] | ||
y, logjac = with_logabsdet_jacobian(b, x) | ||
|
||
lp_new = getlogprior(vi) - logjac | ||
vi_new = setlogprior!!(unflatten(vi, y), lp_new) | ||
# Set parameters | ||
vi_new = unflatten(vi, y) | ||
# Update logjac. We can overwrite any old value since there is only | ||
# a single logjac term to worry about. | ||
vi_new = setlogjac!!(vi_new, logjac) | ||
return settrans!!(vi_new, t) | ||
end | ||
|
||
|
@@ -846,10 +920,14 @@ function invlink!!( | |
) | ||
b = t.bijector | ||
y = vi[:] | ||
x, logjac = with_logabsdet_jacobian(b, y) | ||
x = b(y) | ||
|
||
lp_new = getlogprior(vi) + logjac | ||
vi_new = setlogprior!!(unflatten(vi, x), lp_new) | ||
# Set parameters | ||
vi_new = unflatten(vi, x) | ||
# Reset logjac to 0. | ||
if hasacc(vi_new, Val(:LogJacobian)) | ||
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian)) | ||
|
||
end | ||
return settrans!!(vi_new, NoTransformation()) | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,21 @@ seen so far. | |
|
||
An accumulator type `T <: AbstractAccumulator` must implement the following methods: | ||
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` | ||
- `accumulate_observe!!(acc::T, right, left, vn)` | ||
- `accumulate_assume!!(acc::T, val, logjac, vn, right)` | ||
- `accumulate_observe!!(acc::T, dist, val, vn)` | ||
- `accumulate_assume!!(acc::T, val, logjac, vn, dist)` | ||
- `Base.copy(acc::T)` | ||
|
||
In these functions: | ||
- `val` is the new value of the random variable sampled from a new distribution (always | ||
|
||
in the original unlinked space), or the value on the left-hand side of an observe | ||
statement. | ||
- `dist` is the distribution on the RHS of the tilde statement. | ||
- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the | ||
tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. | ||
- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the | ||
variable is stored as a linked value in the VarInfo. If the variable is stored in its | ||
original, unlinked form, then `logjac` is zero. | ||
|
||
To be able to work with multi-threading, it should also implement: | ||
- `split(acc::T)` | ||
- `combine(acc::T, acc2::T)` | ||
|
Uh oh!
There was an error while loading. Please reload this page.