Skip to content

Commit 02d1d0e

Browse files
committed
[no ci] Fix Gibbs
1 parent 9bc58c8 commit 02d1d0e

File tree

1 file changed

+7
-53
lines changed

1 file changed

+7
-53
lines changed

src/mcmc/gibbs.jl

Lines changed: 7 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
140140
end
141141

142142
# Tilde pipeline
143-
function DynamicPPL.tilde_assume!!(context::GibbsContext, right, vn, vi)
143+
function DynamicPPL.tilde_assume!!(
144+
context::GibbsContext, right::Distribution, vn::VarName, vi::DynamicPPL.AbstractVarInfo
145+
)
144146
child_context = DynamicPPL.childcontext(context)
145147

146148
# Note that `child_context` may contain `PrefixContext`s -- in which case
@@ -204,47 +206,6 @@ function DynamicPPL.tilde_assume!!(context::GibbsContext, right, vn, vi)
204206
end
205207
end
206208

207-
# As above but with an RNG.
208-
function DynamicPPL.tilde_assume!!(
209-
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi
210-
)
211-
# See comment in the above, rng-less version of this method for an explanation.
212-
child_context = DynamicPPL.childcontext(context)
213-
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
214-
215-
return if is_target_varname(context, vn)
216-
# This branch means that that `sampler` is supposed to handle
217-
# this variable. We can thus use its default behaviour, with
218-
# the 'local' sampler-specific VarInfo.
219-
DynamicPPL.tilde_assume!!(rng, child_context, sampler, right, vn, vi)
220-
elseif has_conditioned_gibbs(context, vn)
221-
# This branch means that a different sampler is supposed to handle this
222-
# variable. From the perspective of this sampler, this variable is
223-
# conditioned on, so we can just treat it as an observation.
224-
# The only catch is that the value that we need is to be obtained from
225-
# the global VarInfo (since the local VarInfo has no knowledge of it).
226-
# Note that tilde_observe!! will trigger resampling in particle methods
227-
# for variables that are handled by other Gibbs component samplers.
228-
val = get_conditioned_gibbs(context, vn)
229-
DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
230-
else
231-
# If the varname has not been conditioned on, nor is it a target variable, its
232-
# presumably a new variable that should be sampled from its prior. We need to add
233-
# this new variable to the global `varinfo` of the context, but not to the local one
234-
# being used by the current sampler.
235-
value, new_global_vi = DynamicPPL.tilde_assume!!(
236-
# child_context might be a PrefixContext so we have to be careful to not
237-
# overwrite it.
238-
DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext(rng)),
239-
right,
240-
vn,
241-
get_global_varinfo(context),
242-
)
243-
set_global_varinfo!(context, new_global_vi)
244-
value, vi
245-
end
246-
end
247-
248209
"""
249210
make_conditional(model, target_variables, varinfo)
250211
@@ -363,7 +324,7 @@ function AbstractMCMC.step(
363324
rng::Random.AbstractRNG,
364325
model::DynamicPPL.Model,
365326
spl::DynamicPPL.Sampler{<:Gibbs};
366-
initial_params=nothing,
327+
initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl),
367328
kwargs...,
368329
)
369330
alg = spl.alg
@@ -388,7 +349,7 @@ function AbstractMCMC.step_warmup(
388349
rng::Random.AbstractRNG,
389350
model::DynamicPPL.Model,
390351
spl::DynamicPPL.Sampler{<:Gibbs};
391-
initial_params=nothing,
352+
initial_params::DynamicPPL.AbstractInitStrategy=DynamicPPL.init_strategy(spl),
392353
kwargs...,
393354
)
394355
alg = spl.alg
@@ -425,7 +386,7 @@ function gibbs_initialstep_recursive(
425386
samplers,
426387
vi,
427388
states=();
428-
initial_params=nothing,
389+
initial_params,
429390
kwargs...,
430391
)
431392
# End recursion
@@ -436,13 +397,6 @@ function gibbs_initialstep_recursive(
436397
varnames, varname_vecs_tail... = varname_vecs
437398
sampler, samplers_tail... = samplers
438399

439-
# Get the initial values for this component sampler.
440-
initial_params_local = if initial_params === nothing
441-
nothing
442-
else
443-
DynamicPPL.subset(vi, varnames)[:]
444-
end
445-
446400
# Construct the conditioned model.
447401
conditioned_model, context = make_conditional(model, varnames, vi)
448402

@@ -453,7 +407,7 @@ function gibbs_initialstep_recursive(
453407
sampler;
454408
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
455409
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
456-
initial_params=initial_params_local,
410+
initial_params=initial_params,
457411
kwargs...,
458412
)
459413
new_vi_local = get_varinfo(new_state)

0 commit comments

Comments
 (0)