@@ -907,28 +907,29 @@ function link!!(
907
907
x = vi[:]
908
908
y, logjac = with_logabsdet_jacobian (b, x)
909
909
910
- # Set parameters
911
- vi_new = unflatten (vi, y)
912
- # Update logjac. We can overwrite any old value since there is only
913
- # a single logjac term to worry about.
914
- vi_new = setlogjac!! (vi_new, logjac)
915
- return settrans!! (vi_new , t)
910
+ # Set parameters and add the logjac term.
911
+ vi = unflatten (vi, y)
912
+ if hasacc (vi, Val ( :LogJacobian ))
913
+ vi = acclogjac!! (vi, logjac)
914
+ end
915
+ return settrans!! (vi , t)
916
916
end
917
917
918
918
function invlink!! (
919
919
t:: StaticTransformation{<:Bijectors.Transform} , vi:: AbstractVarInfo , :: Model
920
920
)
921
921
b = t. bijector
922
922
y = vi[:]
923
- x = b ( y)
923
+ x, inv_logjac = with_logabsdet_jacobian (b, y)
924
924
925
- # Set parameters
926
- vi_new = unflatten (vi, x)
927
- # Reset logjac to 0.
928
- if hasacc (vi_new, Val (:LogJacobian ))
929
- vi_new = map_accumulator!! (zero, vi_new, Val (:LogJacobian ))
925
+ # Mildly confusing: we need to _add_ the logjac of the inverse transform,
926
+ # because we are trying to remove the logjac of the forward transform
927
+ # that was previously accumulated when linking.
928
+ vi = unflatten (vi, x)
929
+ if hasacc (vi, Val (:LogJacobian ))
930
+ vi = acclogjac!! (vi, inv_logjac)
930
931
end
931
- return settrans!! (vi_new , NoTransformation ())
932
+ return settrans!! (vi , NoTransformation ())
932
933
end
933
934
934
935
"""
0 commit comments