@@ -153,6 +153,25 @@ function MH(model::Model; proposal_type=AMH.StaticProposal)
153
153
return AMH. MetropolisHastings (priors)
154
154
end
155
155
156
+ """
157
+ MHState(varinfo::AbstractVarInfo, logjoint_internal::Real)
158
+
159
+ State for Metropolis-Hastings sampling.
160
+
161
+ `varinfo` must have the correct parameters set inside it, but its other fields
162
+ (e.g. accumulators, which track logp) can in general be missing or incorrect.
163
+
164
+ `logjoint_internal` is the log joint probability of the model, evaluated using
165
+ the parameters and linking status of `varinfo`. It should be equal to
166
+ `DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the
167
+ MH sampler so we store this here to avoid re-evaluating the model
168
+ unnecessarily.
169
+ """
170
+ struct MHState{V<: AbstractVarInfo ,L<: Real }
171
+ varinfo:: V
172
+ logjoint_internal:: L
173
+ end
174
+
156
175
# ####################
157
176
# Utility functions #
158
177
# ####################
@@ -297,14 +316,15 @@ end
297
316
298
317
# Make a proposal if we don't have a covariance proposal matrix (the default).
299
318
function propose!! (
300
- rng:: AbstractRNG , vi :: AbstractVarInfo , model:: Model , spl:: Sampler{<:MH} , proposal
319
+ rng:: AbstractRNG , prev_state :: MHState , model:: Model , spl:: Sampler{<:MH} , proposal
301
320
)
321
+ vi = prev_state. varinfo
302
322
# Retrieve distribution and value NamedTuples.
303
323
dt, vt = dist_val_tuple (spl, vi)
304
324
305
325
# Create a sampler and the previous transition.
306
326
mh_sampler = AMH. MetropolisHastings (dt)
307
- prev_trans = AMH. Transition (vt, DynamicPPL . getlogjoint_internal (vi) , false )
327
+ prev_trans = AMH. Transition (vt, prev_state . logjoint_internal , false )
308
328
309
329
# Make a new transition.
310
330
spl_model = DynamicPPL. contextualize (
@@ -319,24 +339,29 @@ function propose!!(
319
339
trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
320
340
# trans.params isa NamedTuple
321
341
set_namedtuple! (vi, trans. params)
322
- return vi
342
+ # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
343
+ # how to set this back inside vi (without re-evaluating). However, the next
344
+ # MH step will require this information to calculate the acceptance
345
+ # probability, so we return it together with vi.
346
+ return MHState (vi, trans. lp)
323
347
end
324
348
325
349
# Make a proposal if we DO have a covariance proposal matrix.
326
350
function propose!! (
327
351
rng:: AbstractRNG ,
328
- vi :: AbstractVarInfo ,
352
+ prev_state :: MHState ,
329
353
model:: Model ,
330
354
spl:: Sampler{<:MH} ,
331
355
proposal:: AdvancedMH.RandomWalkProposal ,
332
356
)
357
+ vi = prev_state. varinfo
333
358
# If this is the case, we can just draw directly from the proposal
334
359
# matrix.
335
360
vals = vi[:]
336
361
337
362
# Create a sampler and the previous transition.
338
363
mh_sampler = AMH. MetropolisHastings (spl. alg. proposals)
339
- prev_trans = AMH. Transition (vals, DynamicPPL . getlogjoint_internal (vi) , false )
364
+ prev_trans = AMH. Transition (vals, prev_state . logjoint_internal , false )
340
365
341
366
# Make a new transition.
342
367
spl_model = DynamicPPL. contextualize (
@@ -350,7 +375,12 @@ function propose!!(
350
375
)
351
376
trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
352
377
# trans.params isa AbstractVector
353
- return DynamicPPL. unflatten (vi, trans. params)
378
+ vi = DynamicPPL. unflatten (vi, trans. params)
379
+ # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
380
+ # how to set this back inside vi (without re-evaluating). However, the next
381
+ # MH step will require this information to calculate the acceptance
382
+ # probability, so we return it together with vi.
383
+ return MHState (vi, trans. lp)
354
384
end
355
385
356
386
function DynamicPPL. initialstep (
@@ -364,18 +394,18 @@ function DynamicPPL.initialstep(
364
394
# just link everything before sampling.
365
395
vi = maybe_link!! (vi, spl, spl. alg. proposals, model)
366
396
367
- return Transition (model, vi, nothing ), vi
397
+ return Transition (model, vi, nothing ), MHState (vi, DynamicPPL . getlogjoint_internal (vi))
368
398
end
369
399
370
400
function AbstractMCMC. step (
371
- rng:: AbstractRNG , model:: Model , spl:: Sampler{<:MH} , vi :: AbstractVarInfo ; kwargs...
401
+ rng:: AbstractRNG , model:: Model , spl:: Sampler{<:MH} , state :: MHState ; kwargs...
372
402
)
373
403
# Cases:
374
404
# 1. A covariance proposal matrix
375
405
# 2. A bunch of NamedTuples that specify the proposal space
376
- new_vi = propose!! (rng, vi , model, spl, spl. alg. proposals)
406
+ new_state = propose!! (rng, state , model, spl, spl. alg. proposals)
377
407
378
- return Transition (model, new_vi , nothing ), new_vi
408
+ return Transition (model, new_state . varinfo , nothing ), new_state
379
409
end
380
410
381
411
# ###
0 commit comments