Skip to content

Commit 10a130a

Browse files
committed
fix HMC log-density
1 parent fd5a815 commit 10a130a

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

src/mcmc/hmc.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ function DynamicPPL.initialstep(
206206
end
207207
theta = vi[:]
208208

209-
# Cache current log density.
210-
log_density_old = getloglikelihood(vi)
209+
# Cache current log density. We will reuse this if the transition is rejected.
210+
logp_old = DynamicPPL.getlogp(vi)
211211

212212
# Find good eps if not provided one
213213
if iszero(spl.alg.ϵ)
@@ -232,15 +232,21 @@ function DynamicPPL.initialstep(
232232
)
233233
end
234234

235-
# Update `vi` based on acceptance
235+
# Update VarInfo based on acceptance
236236
if t.stat.is_accept
237237
vi = DynamicPPL.unflatten(vi, t.z.θ)
238-
# TODO(mhauru) Is setloglikelihood! the right thing here?
239-
vi = setloglikelihood!!(vi, t.stat.log_density)
238+
# Re-evaluate to calculate log probability density.
239+
# TODO(penelopeysm): This seems a little bit wasteful. The need for
240+
# this stems from the fact that the HMC sampler doesn't keep track of
241+
# prior and likelihood separately but rather a single log-joint, for
242+
# which we have no way to decompose this back into prior and
243+
# likelihood. I don't immediately see how to solve this without
244+
# re-evaluating the model.
245+
vi = DynamicPPL.evaluate!!(model, vi)
240246
else
247+
# Reset VarInfo back to its original state.
241248
vi = DynamicPPL.unflatten(vi, theta)
242-
# TODO(mhauru) Is setloglikelihood! the right thing here?
243-
vi = setloglikelihood!!(vi, log_density_old)
249+
vi = DynamicPPL.setlogp!!(vi, logp_old)
244250
end
245251

246252
transition = Transition(model, vi, t)

0 commit comments

Comments
 (0)