Skip to content

Commit 31f7331

Browse files
authored
Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628)
* Obviously this single commit will make Gibbs work * Fixes for ESS * Fix HMC call * improve some comments
1 parent b0df6a6 commit 31f7331

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

src/mcmc/ess.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function AbstractMCMC.step(
5454

5555
# update sample and log-likelihood
5656
vi = DynamicPPL.unflatten(vi, sample)
57-
vi = setloglikelihood!!(vi, state.loglikelihood)
57+
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
5858

5959
return Transition(model, vi), vi
6060
end
@@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
8888
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
8989
# why we had to use the 'del' flag before this was because
9090
# SampleFromPrior() wouldn't overwrite existing variables.
91+
# The main problem I'm rather unsure about is ESS-within-Gibbs. The
92+
# current implementation I think makes sure to only resample the variables
93+
# that 'belong' to the current ESS sampler. InitContext on the other hand
94+
# would resample all variables in the model (??) Need to think about this
95+
# carefully.
9196
vns = keys(varinfo)
9297
for vn in vns
9398
set_flag!(varinfo, vn, "del")
@@ -102,13 +107,13 @@ Distributions.mean(p::ESSPrior) = p.μ
102107
# Evaluate log-likelihood of proposals. We need this struct because
103108
# EllipticalSliceSampling.jl expects a callable struct / a function as its
104109
# likelihood.
105-
struct ESSLikelihood{M<:Model,V<:AbstractVarInfo}
106-
ldf::DynamicPPL.LogDensityFunction{M,V}
110+
struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction}
111+
ldf::L
107112

108113
# Force usage of `getloglikelihood` in inner constructor
109114
function ESSLikelihood(model::Model, varinfo::AbstractVarInfo)
110115
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo)
111-
return new{typeof(model),typeof(varinfo)}(ldf)
116+
return new{typeof(ldf)}(ldf)
112117
end
113118
end
114119

src/mcmc/gibbs.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,12 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
177177
# Fall back to the default behavior.
178178
DynamicPPL.tilde_assume(child_context, right, vn, vi)
179179
elseif has_conditioned_gibbs(context, vn)
180-
# Short-circuit the tilde assume if `vn` is present in `context`.
181-
# TODO(mhauru) Fix accumulation here. In this branch anything that gets
182-
# accumulated just gets discarded with `_`.
183-
value, _ = DynamicPPL.tilde_assume(
184-
child_context, right, vn, get_global_varinfo(context)
185-
)
186-
value, vi
180+
# TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it
181+
# will trigger resampling. We may need to do a special kind of observe
182+
# that does not trigger resampling.
183+
global_vi = get_global_varinfo(context)
184+
val = global_vi[vn]
185+
DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
187186
else
188187
# If the varname has not been conditioned on, nor is it a target variable, its
189188
# presumably a new variable that should be sampled from its prior. We need to add
@@ -210,13 +209,27 @@ function DynamicPPL.tilde_assume(
210209
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
211210

212211
return if is_target_varname(context, vn)
212+
# This branch means that that `sampler` is supposed to handle
213+
# this variable. We can thus use its default behaviour, with
214+
# the 'local' sampler-specific VarInfo.
213215
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
214216
elseif has_conditioned_gibbs(context, vn)
215-
value, _ = DynamicPPL.tilde_assume(
216-
child_context, right, vn, get_global_varinfo(context)
217-
)
218-
value, vi
217+
# This branch means that a different sampler is supposed to handle this
218+
# variable. From the perspective of this sampler, this variable is
219+
# conditioned on, so we can just treat it as an observation.
220+
# The only catch is that the value that we need is to be obtained from
221+
# the global VarInfo (since the local VarInfo has no knowledge of it).
222+
# TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it
223+
# will trigger resampling. We may need to do a special kind of observe
224+
# that does not trigger resampling.
225+
global_vi = get_global_varinfo(context)
226+
val = global_vi[vn]
227+
DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
219228
else
229+
# If the varname has not been conditioned on, nor is it a target variable, its
230+
# presumably a new variable that should be sampled from its prior. We need to add
231+
# this new variable to the global `varinfo` of the context, but not to the local one
232+
# being used by the current sampler.
220233
value, new_global_vi = DynamicPPL.tilde_assume(
221234
rng,
222235
child_context,

src/mcmc/hmc.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ function find_initial_params(
162162
# Resample and try again.
163163
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space
164164
varinfo = last(
165-
DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform())
165+
DynamicPPL.evaluate_and_sample!!(
166+
rng, model, varinfo, DynamicPPL.SampleFromUniform()
167+
),
166168
)
167169
end
168170

0 commit comments

Comments
 (0)