Skip to content

Commit 519ff02

Browse files
committed
Fix Gibbs linking bug, add tests
1 parent 310bee9 commit 519ff02

File tree

2 files changed

+68
-16
lines changed

2 files changed

+68
-16
lines changed

src/mcmc/gibbs.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,49 @@ function setparams_varinfo!!(
543543
return PGState(params, state.rng)
544544
end
545545

546+
"""
547+
match_linking!!(varinfo_local, prev_state_local, model)
548+
549+
Make sure the linked/invlinked status of varinfo_local matches that of the previous
550+
state for this sampler. This is relevant when multilple samplers are sampling the same
551+
variables, and one might need it to be linked while the other doesn't.
552+
"""
553+
function match_linking!!(varinfo_local, prev_state_local, model)
554+
prev_varinfo_local = varinfo(prev_state_local)
555+
was_linked = DynamicPPL.istrans(prev_varinfo_local)
556+
is_linked = DynamicPPL.istrans(varinfo_local)
557+
if was_linked && !is_linked
558+
varinfo_local = DynamicPPL.link!!(varinfo_local, model)
559+
elseif !was_linked && is_linked
560+
varinfo_local = DynamicPPL.invlink!!(varinfo_local, model)
561+
end
562+
# TODO(mhauru) The above might run into trouble if some variables are linked and others
563+
# are not. `istrans(varinfo)` returns an `all` over the individual variables. This could
564+
# especially be a problem with dynamic models, where new variables may get introduced,
565+
# but also in cases where component samplers have partial overlap in their target
566+
# variables. The below is how I would like to implement this, but DynamicPPL at this
567+
# time does not support linking individual variables selected by `VarName`. It soon
568+
# should though, so come back to this.
569+
# prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local))
570+
# any_linked = any(values(prev_links_dict))
571+
# for vn in keys(varinfo_local)
572+
# was_linked = if haskey(prev_varinfo_local, vn)
573+
# prev_links_dict[vn]
574+
# else
575+
# # If the old state didn't have this variable, we assume it was linked if _any_
576+
# # of the variables of the old state were linked.
577+
# any_linked
578+
# end
579+
# is_linked = DynamicPPL.istrans(varinfo_local, vn)
580+
# if was_linked && !is_linked
581+
# varinfo_local = DynamicPPL.invlink!!(varinfo_local, vn)
582+
# elseif !was_linked && is_linked
583+
# varinfo_local = DynamicPPL.link!!(varinfo_local, vn)
584+
# end
585+
# end
586+
return varinfo_local
587+
end
588+
546589
function gibbs_step_inner(
547590
rng::Random.AbstractRNG,
548591
model::DynamicPPL.Model,
@@ -555,6 +598,7 @@ function gibbs_step_inner(
555598
# Construct the conditional model and the varinfo that this sampler should use.
556599
model_local, context_local = make_conditional(model, varnames_local, global_vi)
557600
varinfo_local = subset(global_vi, varnames_local)
601+
varinfo_local = match_linking!!(varinfo_local, state_local, model)
558602

559603
# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
560604
# sampled by other samplers, we don't need to `setparams`, but could rather simply

test/mcmc/gibbs.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -350,22 +350,30 @@ end
350350
check_MoGtest_default(chain; atol=0.15)
351351

352352
Random.seed!(200)
353-
for alg in [
354-
# The new syntax for specifying a sampler to run twice for one variable.
355-
Gibbs(
356-
@varname(s) => MH(),
357-
@varname(s) => MH(),
358-
@varname(m) => HMC(0.2, 4; adtype=adbackend),
359-
),
360-
Gibbs(
361-
@varname(s) => MH(),
362-
@varname(m) => HMC(0.2, 4; adtype=adbackend),
363-
@varname(m) => HMC(0.2, 4; adtype=adbackend),
364-
),
365-
]
366-
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
367-
check_gdemo(chain; atol=0.15)
368-
end
353+
# Test samplers that are run multiple times, or have overlapping targets.
354+
alg = Gibbs(
355+
@varname(s) => MH(),
356+
(@varname(s), @varname(m)) => MH(),
357+
@varname(m) => ESS(),
358+
@varname(s) => MH(),
359+
@varname(m) => HMC(0.2, 4; adtype=adbackend),
360+
(@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend),
361+
)
362+
chain = sample(gdemo(1.5, 2.0), alg, 300)
363+
check_gdemo(chain; atol=0.15)
364+
365+
Random.seed!(200)
366+
gibbs = Gibbs(
367+
(@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
368+
(@varname(z1), @varname(z2)) => PG(15),
369+
(@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
370+
(@varname(z3), @varname(z4)) => PG(15),
371+
(@varname(mu1)) => ESS(),
372+
(@varname(mu2)) => ESS(),
373+
(@varname(z1), @varname(z2)) => PG(15),
374+
)
375+
chain = sample(MoGtest_default, gibbs, 300)
376+
check_MoGtest_default(chain; atol=0.15)
369377
end
370378

371379
@testset "transitions" begin

0 commit comments

Comments
 (0)