Skip to content

Commit c795e58

Browse files
committed
Fix tests
1 parent ca06c8b commit c795e58

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

src/mcmc/mh.jl

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,25 @@ function MH(model::Model; proposal_type=AMH.StaticProposal)
153153
return AMH.MetropolisHastings(priors)
154154
end
155155

156+
"""
157+
MHState(varinfo::AbstractVarInfo, logjoint_internal::Real)
158+
159+
State for Metropolis-Hastings sampling.
160+
161+
`varinfo` must have the correct parameters set inside it, but its other fields
162+
(e.g. accumulators, which track logp) can in general be missing or incorrect.
163+
164+
`logjoint_internal` is the log joint probability of the model, evaluated using
165+
the parameters and linking status of `varinfo`. It should be equal to
166+
`DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the
167+
MH sampler so we store this here to avoid re-evaluating the model
168+
unnecessarily.
169+
"""
170+
struct MHState{V<:AbstractVarInfo,L<:Real}
171+
varinfo::V
172+
logjoint_internal::L
173+
end
174+
156175
#####################
157176
# Utility functions #
158177
#####################
@@ -297,14 +316,15 @@ end
297316

298317
# Make a proposal if we don't have a covariance proposal matrix (the default).
299318
function propose!!(
300-
rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal
319+
rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal
301320
)
321+
vi = prev_state.varinfo
302322
# Retrieve distribution and value NamedTuples.
303323
dt, vt = dist_val_tuple(spl, vi)
304324

305325
# Create a sampler and the previous transition.
306326
mh_sampler = AMH.MetropolisHastings(dt)
307-
prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint_internal(vi), false)
327+
prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false)
308328

309329
# Make a new transition.
310330
spl_model = DynamicPPL.contextualize(
@@ -319,24 +339,29 @@ function propose!!(
319339
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
320340
# trans.params isa NamedTuple
321341
set_namedtuple!(vi, trans.params)
322-
return vi
342+
# Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
343+
# how to set this back inside vi (without re-evaluating). However, the next
344+
# MH step will require this information to calculate the acceptance
345+
# probability, so we return it together with vi.
346+
return MHState(vi, trans.lp)
323347
end
324348

325349
# Make a proposal if we DO have a covariance proposal matrix.
326350
function propose!!(
327351
rng::AbstractRNG,
328-
vi::AbstractVarInfo,
352+
prev_state::MHState,
329353
model::Model,
330354
spl::Sampler{<:MH},
331355
proposal::AdvancedMH.RandomWalkProposal,
332356
)
357+
vi = prev_state.varinfo
333358
# If this is the case, we can just draw directly from the proposal
334359
# matrix.
335360
vals = vi[:]
336361

337362
# Create a sampler and the previous transition.
338363
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
339-
prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint_internal(vi), false)
364+
prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false)
340365

341366
# Make a new transition.
342367
spl_model = DynamicPPL.contextualize(
@@ -350,7 +375,12 @@ function propose!!(
350375
)
351376
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
352377
# trans.params isa AbstractVector
353-
return DynamicPPL.unflatten(vi, trans.params)
378+
vi = DynamicPPL.unflatten(vi, trans.params)
379+
# Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
380+
# how to set this back inside vi (without re-evaluating). However, the next
381+
# MH step will require this information to calculate the acceptance
382+
# probability, so we return it together with vi.
383+
return MHState(vi, trans.lp)
354384
end
355385

356386
function DynamicPPL.initialstep(
@@ -364,18 +394,18 @@ function DynamicPPL.initialstep(
364394
# just link everything before sampling.
365395
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)
366396

367-
return Transition(model, vi, nothing), vi
397+
return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi))
368398
end
369399

370400
function AbstractMCMC.step(
371-
rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs...
401+
rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs...
372402
)
373403
# Cases:
374404
# 1. A covariance proposal matrix
375405
# 2. A bunch of NamedTuples that specify the proposal space
376-
new_vi = propose!!(rng, vi, model, spl, spl.alg.proposals)
406+
new_state = propose!!(rng, state, model, spl, spl.alg.proposals)
377407

378-
return Transition(model, new_vi, nothing), new_vi
408+
return Transition(model, new_state.varinfo, nothing), new_state
379409
end
380410

381411
####

test/mcmc/is.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ using Turing
4747

4848
Random.seed!(seed)
4949
chain = sample(model, alg, n; check_model=false)
50-
sampled = get(chain, [:a, :b, :lp])
50+
sampled = get(chain, [:a, :b, :loglikelihood])
5151

5252
@test vec(sampled.a) == ref.as
5353
@test vec(sampled.b) == ref.bs
54-
@test vec(sampled.lp) == ref.logps
54+
@test vec(sampled.loglikelihood) == ref.logps
5555
@test chain.logevidence == ref.logevidence
5656
end
5757

0 commit comments

Comments
 (0)