Skip to content

Commit c062867

Browse files
mhaurupenelopeysm
andauthored
DPPL 0.37 compat for particle MCMC (#2625)
* Progress in DPPL 0.37 compat for particle MCMC * WIP PMCMC work * Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628) * Obviously this single commit will make Gibbs work * Fixes for ESS * Fix HMC call * improve some comments * Fixes to ProduceLogLikelihoodAccumulator * Use LogProbAccumulator for ProduceLogLikelihoodAccumulator * use get_conditioned_gibbs --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 7ca59ce commit c062867

File tree

4 files changed

+188
-84
lines changed

4 files changed

+188
-84
lines changed

src/mcmc/ess.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function AbstractMCMC.step(
5454

5555
# update sample and log-likelihood
5656
vi = DynamicPPL.unflatten(vi, sample)
57-
vi = setloglikelihood!!(vi, state.loglikelihood)
57+
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
5858

5959
return Transition(model, vi), vi
6060
end
@@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
8888
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
8989
# why we had to use the 'del' flag before this was because
9090
# SampleFromPrior() wouldn't overwrite existing variables.
91+
# The main problem I'm rather unsure about is ESS-within-Gibbs. The
92+
# current implementation I think makes sure to only resample the variables
93+
# that 'belong' to the current ESS sampler. InitContext on the other hand
94+
# would resample all variables in the model (??) Need to think about this
95+
# carefully.
9196
vns = keys(varinfo)
9297
for vn in vns
9398
set_flag!(varinfo, vn, "del")
@@ -102,13 +107,13 @@ Distributions.mean(p::ESSPrior) = p.μ
102107
# Evaluate log-likelihood of proposals. We need this struct because
103108
# EllipticalSliceSampling.jl expects a callable struct / a function as its
104109
# likelihood.
105-
struct ESSLikelihood{M<:Model,V<:AbstractVarInfo}
106-
ldf::DynamicPPL.LogDensityFunction{M,V}
110+
struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction}
111+
ldf::L
107112

108113
# Force usage of `getloglikelihood` in inner constructor
109114
function ESSLikelihood(model::Model, varinfo::AbstractVarInfo)
110115
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo)
111-
return new{typeof(model),typeof(varinfo)}(ldf)
116+
return new{typeof(ldf)}(ldf)
112117
end
113118
end
114119

src/mcmc/gibbs.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,15 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
177177
# Fall back to the default behavior.
178178
DynamicPPL.tilde_assume(child_context, right, vn, vi)
179179
elseif has_conditioned_gibbs(context, vn)
180-
# Short-circuit the tilde assume if `vn` is present in `context`.
181-
# TODO(mhauru) Fix accumulation here. In this branch anything that gets
182-
# accumulated just gets discarded with `_`.
183-
value, _ = DynamicPPL.tilde_assume(
184-
child_context, right, vn, get_global_varinfo(context)
185-
)
186-
value, vi
180+
# This branch means that a different sampler is supposed to handle this
181+
# variable. From the perspective of this sampler, this variable is
182+
# conditioned on, so we can just treat it as an observation.
183+
# The only catch is that the value that we need is to be obtained from
184+
# the global VarInfo (since the local VarInfo has no knowledge of it).
185+
# Note that tilde_observe!! will trigger resampling in particle methods
186+
# for variables that are handled by other Gibbs component samplers.
187+
val = get_conditioned_gibbs(context, vn)
188+
DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
187189
else
188190
# If the varname has not been conditioned on, nor is it a target variable, its
189191
# presumably a new variable that should be sampled from its prior. We need to add
@@ -210,13 +212,25 @@ function DynamicPPL.tilde_assume(
210212
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
211213

212214
return if is_target_varname(context, vn)
215+
# This branch means that that `sampler` is supposed to handle
216+
# this variable. We can thus use its default behaviour, with
217+
# the 'local' sampler-specific VarInfo.
213218
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
214219
elseif has_conditioned_gibbs(context, vn)
215-
value, _ = DynamicPPL.tilde_assume(
216-
child_context, right, vn, get_global_varinfo(context)
217-
)
218-
value, vi
220+
# This branch means that a different sampler is supposed to handle this
221+
# variable. From the perspective of this sampler, this variable is
222+
# conditioned on, so we can just treat it as an observation.
223+
# The only catch is that the value that we need is to be obtained from
224+
# the global VarInfo (since the local VarInfo has no knowledge of it).
225+
# Note that tilde_observe!! will trigger resampling in particle methods
226+
# for variables that are handled by other Gibbs component samplers.
227+
val = get_conditioned_gibbs(context, vn)
228+
DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
219229
else
230+
# If the varname has not been conditioned on, nor is it a target variable, its
231+
# presumably a new variable that should be sampled from its prior. We need to add
232+
# this new variable to the global `varinfo` of the context, but not to the local one
233+
# being used by the current sampler.
220234
value, new_global_vi = DynamicPPL.tilde_assume(
221235
rng,
222236
child_context,

src/mcmc/hmc.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ function find_initial_params(
162162
# Resample and try again.
163163
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space
164164
varinfo = last(
165-
DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform())
165+
DynamicPPL.evaluate_and_sample!!(
166+
rng, model, varinfo, DynamicPPL.SampleFromUniform()
167+
),
166168
)
167169
end
168170

0 commit comments

Comments
 (0)