@@ -204,7 +204,7 @@ Jacobian here is taken with respect to the forward (link) transform.
204
204
205
205
See also: [`setlogjac!!`](@ref).
206
206
"""
207
- getlogjac (vi:: AbstractVarInfo ) = getacc (vi, Val (:LogJacobian )). logJ
207
+ getlogjac (vi:: AbstractVarInfo ) = getacc (vi, Val (:LogJacobian )). logjac
208
208
209
209
"""
210
210
getloglikelihood(vi::AbstractVarInfo)
@@ -239,14 +239,14 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
239
239
setlogprior!! (vi:: AbstractVarInfo , logp) = setacc!! (vi, LogPriorAccumulator (logp))
240
240
241
241
"""
242
- setlogjac!!(vi::AbstractVarInfo, logJ )
242
+ setlogjac!!(vi::AbstractVarInfo, logjac )
243
243
244
244
Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
245
245
Jacobian here is taken with respect to the forward (link) transform.
246
246
247
247
See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref).
248
248
"""
249
- setlogjac!! (vi:: AbstractVarInfo , logJ ) = setacc!! (vi, LogJacobianAccumulator (logJ ))
249
+ setlogjac!! (vi:: AbstractVarInfo , logjac ) = setacc!! (vi, LogJacobianAccumulator (logjac ))
250
250
251
251
"""
252
252
setloglikelihood!!(vi::AbstractVarInfo, logp)
@@ -372,6 +372,19 @@ function acclogjac!!(vi::AbstractVarInfo, logJ)
372
372
return map_accumulator!! (acc -> acclogp (acc, logJ), vi, Val (:LogJacobian ))
373
373
end
374
374
375
+ """
376
+ acclogjac!!(vi::AbstractVarInfo, logjac)
377
+
378
+ Add `logjac` to the value of the log Jacobian in `vi`.
379
+
380
+ See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
381
+ """
382
+ function acclogjac!! (vi:: AbstractVarInfo , logjac)
383
+ return map_accumulator!! (
384
+ acc -> acc + LogJacobianAccumulator (logjac), vi, Val (:LogJacobian )
385
+ )
386
+ end
387
+
375
388
"""
376
389
accloglikelihood!!(vi::AbstractVarInfo, logp)
377
390
@@ -903,28 +916,29 @@ function link!!(
903
916
x = vi[:]
904
917
y, logjac = with_logabsdet_jacobian (b, x)
905
918
906
- # Set parameters
907
- vi_new = unflatten (vi, y)
908
- # Update logjac. We can overwrite any old value since there is only
909
- # a single logjac term to worry about.
910
- vi_new = setlogjac!! (vi_new, logjac)
911
- return settrans!! (vi_new , t)
919
+ # Set parameters and add the logjac term.
920
+ vi = unflatten (vi, y)
921
+ if hasacc (vi, Val ( :LogJacobian ))
922
+ vi = acclogjac!! (vi, logjac)
923
+ end
924
+ return settrans!! (vi , t)
912
925
end
913
926
914
927
function invlink!! (
915
928
t:: StaticTransformation{<:Bijectors.Transform} , vi:: AbstractVarInfo , :: Model
916
929
)
917
930
b = t. bijector
918
931
y = vi[:]
919
- x = b ( y)
932
+ x, inv_logjac = with_logabsdet_jacobian (b, y)
920
933
921
- # Set parameters
922
- vi_new = unflatten (vi, x)
923
- # Reset logjac to 0.
924
- if hasacc (vi_new, Val (:LogJacobian ))
925
- vi_new = map_accumulator!! (zero, vi_new, Val (:LogJacobian ))
934
+ # Mildly confusing: we need to _add_ the logjac of the inverse transform,
935
+ # because we are trying to remove the logjac of the forward transform
936
+ # that was previously accumulated when linking.
937
+ vi = unflatten (vi, x)
938
+ if hasacc (vi, Val (:LogJacobian ))
939
+ vi = acclogjac!! (vi, inv_logjac)
926
940
end
927
- return settrans!! (vi_new , NoTransformation ())
941
+ return settrans!! (vi , NoTransformation ())
928
942
end
929
943
930
944
"""
0 commit comments