Skip to content

Commit e8c5a70

Browse files
committed
Merge remote-tracking branch 'origin/breaking' into mhauru/logprobacc
2 parents e852b5e + bd99d4f commit e8c5a70

File tree

4 files changed

+40
-26
lines changed

4 files changed

+40
-26
lines changed

src/abstract_varinfo.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ Jacobian here is taken with respect to the forward (link) transform.
204204
205205
See also: [`setlogjac!!`](@ref).
206206
"""
207-
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ
207+
getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logjac
208208

209209
"""
210210
getloglikelihood(vi::AbstractVarInfo)
@@ -239,14 +239,14 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
239239
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))
240240

241241
"""
242-
setlogjac!!(vi::AbstractVarInfo, logJ)
242+
setlogjac!!(vi::AbstractVarInfo, logjac)
243243
244244
Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
245245
Jacobian here is taken with respect to the forward (link) transform.
246246
247247
See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref).
248248
"""
249-
setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ))
249+
setlogjac!!(vi::AbstractVarInfo, logjac) = setacc!!(vi, LogJacobianAccumulator(logjac))
250250

251251
"""
252252
setloglikelihood!!(vi::AbstractVarInfo, logp)
@@ -372,6 +372,19 @@ function acclogjac!!(vi::AbstractVarInfo, logJ)
372372
return map_accumulator!!(acc -> acclogp(acc, logJ), vi, Val(:LogJacobian))
373373
end
374374

375+
"""
376+
acclogjac!!(vi::AbstractVarInfo, logjac)
377+
378+
Add `logjac` to the value of the log Jacobian in `vi`.
379+
380+
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
381+
"""
382+
function acclogjac!!(vi::AbstractVarInfo, logjac)
383+
return map_accumulator!!(
384+
acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian)
385+
)
386+
end
387+
375388
"""
376389
accloglikelihood!!(vi::AbstractVarInfo, logp)
377390
@@ -903,28 +916,29 @@ function link!!(
903916
x = vi[:]
904917
y, logjac = with_logabsdet_jacobian(b, x)
905918

906-
# Set parameters
907-
vi_new = unflatten(vi, y)
908-
# Update logjac. We can overwrite any old value since there is only
909-
# a single logjac term to worry about.
910-
vi_new = setlogjac!!(vi_new, logjac)
911-
return settrans!!(vi_new, t)
919+
# Set parameters and add the logjac term.
920+
vi = unflatten(vi, y)
921+
if hasacc(vi, Val(:LogJacobian))
922+
vi = acclogjac!!(vi, logjac)
923+
end
924+
return settrans!!(vi, t)
912925
end
913926

914927
function invlink!!(
915928
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
916929
)
917930
b = t.bijector
918931
y = vi[:]
919-
x = b(y)
932+
x, inv_logjac = with_logabsdet_jacobian(b, y)
920933

921-
# Set parameters
922-
vi_new = unflatten(vi, x)
923-
# Reset logjac to 0.
924-
if hasacc(vi_new, Val(:LogJacobian))
925-
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
934+
# Mildly confusing: we need to _add_ the logjac of the inverse transform,
935+
# because we are trying to remove the logjac of the forward transform
936+
# that was previously accumulated when linking.
937+
vi = unflatten(vi, x)
938+
if hasacc(vi, Val(:LogJacobian))
939+
vi = acclogjac!!(vi, inv_logjac)
926940
end
927-
return settrans!!(vi_new, NoTransformation())
941+
return settrans!!(vi, NoTransformation())
928942
end
929943

930944
"""

src/accumulators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
1616
- `Base.copy(acc::T)`
1717
1818
In these functions:
19-
- `val` is the new value of the random variable sampled from a new distribution (always
20-
in the original unlinked space), or the value on the left-hand side of an observe
19+
- `val` is the new value of the random variable sampled from a distribution (always in
20+
the original unlinked space), or the value on the left-hand side of an observe
2121
statement.
2222
- `dist` is the distribution on the RHS of the tilde statement.
2323
- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the

src/default_accumulators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ $(TYPEDFIELDS)
129129
"""
130130
struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T}
131131
"the logabsdet of the link transform Jacobian"
132-
logJ::T
132+
logjac::T
133133
end
134134

135-
logp(acc::LogJacobianAccumulator) = acc.logJ
135+
logp(acc::LogJacobianAccumulator) = acc.logjac
136136

137137
accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian
138138

src/simple_varinfo.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,8 @@ function link!!(
622622
x = vi.values
623623
y, logjac = with_logabsdet_jacobian(b, x)
624624
vi_new = Accessors.@set(vi.values = y)
625-
# Since there's only a single transformation, we can overwrite any previous
626-
# value in logjac.
627625
if hasacc(vi_new, Val(:LogJacobian))
628-
vi_new = setlogjac!!(vi_new, logjac)
626+
vi_new = acclogjac!!(vi_new, logjac)
629627
end
630628
return settrans!!(vi_new, t)
631629
end
@@ -637,11 +635,13 @@ function invlink!!(
637635
)
638636
b = t.bijector
639637
y = vi.values
640-
x = b(y)
638+
x, inv_logjac = with_logabsdet_jacobian(b, y)
641639
vi_new = Accessors.@set(vi.values = x)
642-
# logjac should be zero for an unlinked VarInfo.
640+
# Mildly confusing: we need to _add_ the logjac of the inverse transform,
641+
# because we are trying to remove the logjac of the forward transform
642+
# that was previously accumulated when linking.
643643
if hasacc(vi_new, Val(:LogJacobian))
644-
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
644+
vi_new = acclogjac!!(vi_new, inv_logjac)
645645
end
646646
return settrans!!(vi_new, NoTransformation())
647647
end

0 commit comments

Comments
 (0)