Skip to content

Commit 049b3d3

Browse files
committed
Fix logjac accumulation for StaticTransformation
1 parent 48a6048 commit 049b3d3

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

src/abstract_varinfo.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -907,28 +907,29 @@ function link!!(
907907
x = vi[:]
908908
y, logjac = with_logabsdet_jacobian(b, x)
909909

910-
# Set parameters
911-
vi_new = unflatten(vi, y)
912-
# Update logjac. We can overwrite any old value since there is only
913-
# a single logjac term to worry about.
914-
vi_new = setlogjac!!(vi_new, logjac)
915-
return settrans!!(vi_new, t)
910+
# Set parameters and add the logjac term.
911+
vi = unflatten(vi, y)
912+
if hasacc(vi, Val(:LogJacobian))
913+
vi = acclogjac!!(vi, logjac)
914+
end
915+
return settrans!!(vi, t)
916916
end
917917

918918
function invlink!!(
919919
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
920920
)
921921
b = t.bijector
922922
y = vi[:]
923-
x = b(y)
923+
x, inv_logjac = with_logabsdet_jacobian(b, y)
924924

925-
# Set parameters
926-
vi_new = unflatten(vi, x)
927-
# Reset logjac to 0.
928-
if hasacc(vi_new, Val(:LogJacobian))
929-
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
925+
# Mildly confusing: we need to _add_ the logjac of the inverse transform,
926+
# because we are trying to remove the logjac of the forward transform
927+
# that was previously accumulated when linking.
928+
vi = unflatten(vi, x)
929+
if hasacc(vi, Val(:LogJacobian))
930+
vi = acclogjac!!(vi, inv_logjac)
930931
end
931-
return settrans!!(vi_new, NoTransformation())
932+
return settrans!!(vi, NoTransformation())
932933
end
933934

934935
"""

src/simple_varinfo.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,8 @@ function link!!(
620620
x = vi.values
621621
y, logjac = with_logabsdet_jacobian(b, x)
622622
vi_new = Accessors.@set(vi.values = y)
623-
# Since there's only a single transformation, we can overwrite any previous
624-
# value in logjac.
625623
if hasacc(vi_new, Val(:LogJacobian))
626-
vi_new = setlogjac!!(vi_new, logjac)
624+
vi_new = acclogjac!!(vi_new, logjac)
627625
end
628626
return settrans!!(vi_new, t)
629627
end
@@ -635,11 +633,13 @@ function invlink!!(
635633
)
636634
b = t.bijector
637635
y = vi.values
638-
x = b(y)
636+
x, inv_logjac = with_logabsdet_jacobian(b, y)
639637
vi_new = Accessors.@set(vi.values = x)
640-
# logjac should be zero for an unlinked VarInfo.
638+
# Mildly confusing: we need to _add_ the logjac of the inverse transform,
639+
# because we are trying to remove the logjac of the forward transform
640+
# that was previously accumulated when linking.
641641
if hasacc(vi_new, Val(:LogJacobian))
642-
vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian))
642+
vi_new = acclogjac!!(vi_new, inv_logjac)
643643
end
644644
return settrans!!(vi_new, NoTransformation())
645645
end

0 commit comments

Comments
 (0)