Skip to content

Commit 68b974a

Browse files
committed
Make DynamicTransformation not use accumulators other than LogPrior
1 parent 3f195e5 commit 68b974a

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

src/transforming.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ function tilde_assume(
2727
# Only transform if `!isinverse` since `vi[vn, right]`
2828
# already performs the inverse transformation if it's transformed.
2929
r_transformed = isinverse ? r : link_transform(right)(r)
30-
vi = acclogprior!!(vi, lp)
30+
if hasacc(vi, Val(:LogPrior))
31+
vi = acclogprior!!(vi, lp)
32+
end
3133
return r, setindex!!(vi, r_transformed, vn)
3234
end
3335

@@ -36,14 +38,36 @@ function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi)
3638
end
3739

3840
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
39-
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
41+
return _transform!!(t, DynamicTransformationContext{false}(), vi, model)
4042
end
4143

4244
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
43-
return settrans!!(
44-
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
45-
NoTransformation(),
46-
)
45+
return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model)
46+
end
47+
48+
function _transform(
49+
t::AbstractTransformation,
50+
ctx::DynamicTransformationContext,
51+
vi::AbstractVarInfo,
52+
model::Model,
53+
)
54+
# To transform using DynamicTransformationContext, we evaluate the model, but we do not
55+
# need to use any accumulators other than LogPrior (which is affected by the Jacobian of
56+
# the transformation).
57+
accs = getaccs(vi.accs)
58+
has_logprior = hasacc(accs, Val(:LogPrior))
59+
if has_logprior
60+
old_logprior = getacc(accs, Val(:LogPrior))
61+
vi = setaccs!!(vi, (old_logprior,))
62+
end
63+
vi = settrans!!(last(evaluate!!(model, vi, ctx)), t)
64+
# Restore the accumulators.
65+
if has_logprior
66+
new_logprior = getacc(vi, Val(:LogPrior))
67+
accs = setacc!!(accs, new_logprior)
68+
end
69+
vi = setaccs!!(vi, accs)
70+
return vi
4771
end
4872

4973
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)

0 commit comments

Comments
 (0)