@@ -300,20 +300,75 @@ end
300
300
301
301
varinfo (state:: GibbsState ) = state. vi
302
302
303
- function DynamicPPL. initialstep (
303
+ """
304
+ Initialise a VarInfo for the Gibbs sampler.
305
+
306
+ This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated here to
307
+ support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
308
+ incompatible with step_warmup.
309
+ """
310
+ function initial_varinfo (rng, model, spl, initial_params)
311
+ vi = DynamicPPL. default_varinfo (rng, model, spl)
312
+
313
+ # Update the parameters if provided.
314
+ if initial_params != = nothing
315
+ vi = DynamicPPL. initialize_parameters!! (vi, initial_params, spl, model)
316
+
317
+ # Update joint log probability.
318
+ # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
319
+ # and https://github.com/TuringLang/Turing.jl/issues/1563
320
+ # to avoid that existing variables are resampled
321
+ vi = last (DynamicPPL. evaluate!! (model, vi, DynamicPPL. DefaultContext ()))
322
+ end
323
+ return vi
324
+ end
325
+
326
+ function AbstractMCMC. step (
304
327
rng:: Random.AbstractRNG ,
305
328
model:: DynamicPPL.Model ,
306
- spl:: DynamicPPL.Sampler{<:Gibbs} ,
307
- vi:: DynamicPPL.AbstractVarInfo ;
329
+ spl:: DynamicPPL.Sampler{<:Gibbs} ;
308
330
initial_params= nothing ,
309
331
kwargs... ,
310
332
)
311
333
alg = spl. alg
312
334
varnames = alg. varnames
313
335
samplers = alg. samplers
336
+ vi = initial_varinfo (rng, model, spl, initial_params)
314
337
315
338
vi, states = gibbs_initialstep_recursive (
316
- rng, model, varnames, samplers, vi; initial_params= initial_params, kwargs...
339
+ rng,
340
+ model,
341
+ AbstractMCMC. step,
342
+ varnames,
343
+ samplers,
344
+ vi;
345
+ initial_params= initial_params,
346
+ kwargs... ,
347
+ )
348
+ return Transition (model, vi), GibbsState (vi, states)
349
+ end
350
+
351
+ function AbstractMCMC. step_warmup (
352
+ rng:: Random.AbstractRNG ,
353
+ model:: DynamicPPL.Model ,
354
+ spl:: DynamicPPL.Sampler{<:Gibbs} ;
355
+ initial_params= nothing ,
356
+ kwargs... ,
357
+ )
358
+ alg = spl. alg
359
+ varnames = alg. varnames
360
+ samplers = alg. samplers
361
+ vi = initial_varinfo (rng, model, spl, initial_params)
362
+
363
+ vi, states = gibbs_initialstep_recursive (
364
+ rng,
365
+ model,
366
+ AbstractMCMC. step_warmup,
367
+ varnames,
368
+ samplers,
369
+ vi;
370
+ initial_params= initial_params,
371
+ kwargs... ,
317
372
)
318
373
return Transition (model, vi), GibbsState (vi, states)
319
374
end
322
377
Take the first step of MCMC for the first component sampler, and call the same function
323
378
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
324
379
and a tuple of initial states for all component samplers.
380
+
381
+ The `step_function` argument should always be either AbstractMCMC.step or
382
+ AbstractMCMC.step_warmup.
325
383
"""
326
384
function gibbs_initialstep_recursive (
327
- rng, model, varname_vecs, samplers, vi, states= (); initial_params= nothing , kwargs...
385
+ rng,
386
+ model,
387
+ step_function:: Function ,
388
+ varname_vecs,
389
+ samplers,
390
+ vi,
391
+ states= ();
392
+ initial_params= nothing ,
393
+ kwargs... ,
328
394
)
329
395
# End recursion
330
396
if isempty (varname_vecs) && isempty (samplers)
@@ -345,7 +411,7 @@ function gibbs_initialstep_recursive(
345
411
conditioned_model, context = make_conditional (model, varnames, vi)
346
412
347
413
# Take initial step with the current sampler.
348
- _, new_state = AbstractMCMC . step (
414
+ _, new_state = step_function (
349
415
rng,
350
416
conditioned_model,
351
417
sampler;
@@ -365,6 +431,7 @@ function gibbs_initialstep_recursive(
365
431
return gibbs_initialstep_recursive (
366
432
rng,
367
433
model,
434
+ step_function,
368
435
varname_vecs_tail,
369
436
samplers_tail,
370
437
vi,
@@ -388,7 +455,29 @@ function AbstractMCMC.step(
388
455
states = state. states
389
456
@assert length (samplers) == length (state. states)
390
457
391
- vi, states = gibbs_step_recursive (rng, model, varnames, samplers, states, vi; kwargs... )
458
+ vi, states = gibbs_step_recursive (
459
+ rng, model, AbstractMCMC. step, varnames, samplers, states, vi; kwargs...
460
+ )
461
+ return Transition (model, vi), GibbsState (vi, states)
462
+ end
463
+
464
+ function AbstractMCMC. step_warmup (
465
+ rng:: Random.AbstractRNG ,
466
+ model:: DynamicPPL.Model ,
467
+ spl:: DynamicPPL.Sampler{<:Gibbs} ,
468
+ state:: GibbsState ;
469
+ kwargs... ,
470
+ )
471
+ vi = varinfo (state)
472
+ alg = spl. alg
473
+ varnames = alg. varnames
474
+ samplers = alg. samplers
475
+ states = state. states
476
+ @assert length (samplers) == length (state. states)
477
+
478
+ vi, states = gibbs_step_recursive (
479
+ rng, model, AbstractMCMC. step_warmup, varnames, samplers, states, vi; kwargs...
480
+ )
392
481
return Transition (model, vi), GibbsState (vi, states)
393
482
end
394
483
@@ -517,10 +606,14 @@ end
517
606
"""
518
607
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
519
608
function on the tail, until there are no more samplers left.
609
+
610
+ The `step_function` argument should always be either AbstractMCMC.step or
611
+ AbstractMCMC.step_warmup.
520
612
"""
521
613
function gibbs_step_recursive (
522
614
rng:: Random.AbstractRNG ,
523
615
model:: DynamicPPL.Model ,
616
+ step_function:: Function ,
524
617
varname_vecs,
525
618
samplers,
526
619
states,
@@ -554,7 +647,7 @@ function gibbs_step_recursive(
554
647
state = setparams_varinfo!! (conditioned_model, sampler, state, vi)
555
648
556
649
# Take a step with the local sampler.
557
- new_state = last (AbstractMCMC . step (rng, conditioned_model, sampler, state; kwargs... ))
650
+ new_state = last (step_function (rng, conditioned_model, sampler, state; kwargs... ))
558
651
559
652
new_vi_local = varinfo (new_state)
560
653
# Merge the latest values for all the variables in the current sampler.
@@ -565,6 +658,7 @@ function gibbs_step_recursive(
565
658
return gibbs_step_recursive (
566
659
rng,
567
660
model,
661
+ step_function,
568
662
varname_vecs_tail,
569
663
samplers_tail,
570
664
states_tail,
0 commit comments