@@ -543,6 +543,49 @@ function setparams_varinfo!!(
543
543
return PGState (params, state. rng)
544
544
end
545
545
546
+ """
547
+ match_linking!!(varinfo_local, prev_state_local, model)
548
+
549
+ Make sure the linked/invlinked status of varinfo_local matches that of the previous
550
+ state for this sampler. This is relevant when multilple samplers are sampling the same
551
+ variables, and one might need it to be linked while the other doesn't.
552
+ """
553
+ function match_linking!! (varinfo_local, prev_state_local, model)
554
+ prev_varinfo_local = varinfo (prev_state_local)
555
+ was_linked = DynamicPPL. istrans (prev_varinfo_local)
556
+ is_linked = DynamicPPL. istrans (varinfo_local)
557
+ if was_linked && ! is_linked
558
+ varinfo_local = DynamicPPL. link!! (varinfo_local, model)
559
+ elseif ! was_linked && is_linked
560
+ varinfo_local = DynamicPPL. invlink!! (varinfo_local, model)
561
+ end
562
+ # TODO (mhauru) The above might run into trouble if some variables are linked and others
563
+ # are not. `istrans(varinfo)` returns an `all` over the individual variables. This could
564
+ # especially be a problem with dynamic models, where new variables may get introduced,
565
+ # but also in cases where component samplers have partial overlap in their target
566
+ # variables. The below is how I would like to implement this, but DynamicPPL at this
567
+ # time does not support linking individual variables selected by `VarName`. It soon
568
+ # should though, so come back to this.
569
+ # prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local))
570
+ # any_linked = any(values(prev_links_dict))
571
+ # for vn in keys(varinfo_local)
572
+ # was_linked = if haskey(prev_varinfo_local, vn)
573
+ # prev_links_dict[vn]
574
+ # else
575
+ # # If the old state didn't have this variable, we assume it was linked if _any_
576
+ # # of the variables of the old state were linked.
577
+ # any_linked
578
+ # end
579
+ # is_linked = DynamicPPL.istrans(varinfo_local, vn)
580
+ # if was_linked && !is_linked
581
+ # varinfo_local = DynamicPPL.invlink!!(varinfo_local, vn)
582
+ # elseif !was_linked && is_linked
583
+ # varinfo_local = DynamicPPL.link!!(varinfo_local, vn)
584
+ # end
585
+ # end
586
+ return varinfo_local
587
+ end
588
+
546
589
function gibbs_step_inner (
547
590
rng:: Random.AbstractRNG ,
548
591
model:: DynamicPPL.Model ,
@@ -555,6 +598,7 @@ function gibbs_step_inner(
555
598
# Construct the conditional model and the varinfo that this sampler should use.
556
599
model_local, context_local = make_conditional (model, varnames_local, global_vi)
557
600
varinfo_local = subset (global_vi, varnames_local)
601
+ varinfo_local = match_linking!! (varinfo_local, state_local, model)
558
602
559
603
# TODO (mhauru) The below may be overkill. If the varnames for this sampler are not
560
604
# sampled by other samplers, we don't need to `setparams`, but could rather simply
0 commit comments