Skip to content

Commit cd52e9f

Browse files
committed
even more fixes (oh goodness when will this end)
1 parent 195f819 commit cd52e9f

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

src/mcmc/emcee.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,14 @@ function AbstractMCMC.step(
5353
length(initial_params) == n ||
5454
throw(ArgumentError("initial parameters have to be specified for each walker"))
5555
vis = map(vis, initial_params) do vi, init
56+
# TODO(DPPL0.37/penelopeysm) This whole thing can be replaced with init!!
5657
vi = DynamicPPL.initialize_parameters!!(vi, init, model)
5758

5859
# Update log joint probability.
59-
last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))
60+
spl_model = DynamicPPL.contextualize(
61+
model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context)
62+
)
63+
last(DynamicPPL.evaluate!!(spl_model, vi))
6064
end
6165
end
6266

@@ -68,7 +72,7 @@ function AbstractMCMC.step(
6872
vis[1],
6973
map(vis) do vi
7074
vi = DynamicPPL.link!!(vi, model)
71-
AMH.Transition(vi[:], getlogp(vi), false)
75+
AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false)
7276
end,
7377
)
7478

src/mcmc/hmc.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ function DynamicPPL.initialstep(
236236
if t.stat.is_accept
237237
vi = DynamicPPL.unflatten(vi, t.z.θ)
238238
# Re-evaluate to calculate log probability density.
239-
# TODO(penelopeysm): This seems a little bit wasteful. The need for
240-
# this stems from the fact that the HMC sampler doesn't keep track of
241-
# prior and likelihood separately but rather a single log-joint, for
242-
# which we have no way to decompose this back into prior and
243-
# likelihood. I don't immediately see how to solve this without
244-
# re-evaluating the model.
239+
# TODO(penelopeysm): This seems a little bit wasteful. Unfortunately,
240+
# even though `t.stat.log_density` contains some kind of logp, this
241+
# doesn't track prior and likelihood separately but rather a single
242+
# log-joint (and in linked space), so which we have no way to decompose
243+
# this back into prior and likelihood. I don't immediately see how to
244+
# solve this without re-evaluating the model.
245245
_, vi = DynamicPPL.evaluate!!(model, vi)
246246
else
247247
# Reset VarInfo back to its original state.
@@ -291,8 +291,9 @@ function AbstractMCMC.step(
291291
vi = state.vi
292292
if t.stat.is_accept
293293
vi = DynamicPPL.unflatten(vi, t.z.θ)
294-
# TODO(mhauru) Is setloglikelihood! the right thing here?
295-
vi = setloglikelihood!!(vi, t.stat.log_density)
294+
# Re-evaluate to calculate log probability density.
295+
# TODO(penelopeysm): This seems a little bit wasteful. See note above.
296+
_, vi = DynamicPPL.evaluate!!(model, vi)
296297
end
297298

298299
# Compute next transition and state.

src/mcmc/mh.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple)
195195
vi = deepcopy(f.varinfo)
196196
set_namedtuple!(vi, x)
197197
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context))
198-
lj = getlogp(vi_new)
198+
lj = f.getlogdensity(vi_new)
199199
return lj
200200
end
201201

@@ -304,7 +304,7 @@ function propose!!(
304304

305305
# Create a sampler and the previous transition.
306306
mh_sampler = AMH.MetropolisHastings(dt)
307-
prev_trans = AMH.Transition(vt, getlogp(vi), false)
307+
prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false)
308308

309309
# Make a new transition.
310310
densitymodel = AMH.DensityModel(
@@ -339,7 +339,7 @@ function propose!!(
339339

340340
# Create a sampler and the previous transition.
341341
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
342-
prev_trans = AMH.Transition(vals, getlogp(vi), false)
342+
prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false)
343343

344344
# Make a new transition.
345345
densitymodel = AMH.DensityModel(

0 commit comments

Comments
 (0)