diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index c49b52d3a..feb737a30 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -54,7 +54,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setloglikelihood!!(vi, state.loglikelihood) + vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason # why we had to use the 'del' flag before this was because # SampleFromPrior() wouldn't overwrite existing variables. + # The main problem I'm rather unsure about is ESS-within-Gibbs. The + # current implementation I think makes sure to only resample the variables + # that 'belong' to the current ESS sampler. InitContext on the other hand + # would resample all variables in the model (??) Need to think about this + # carefully. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") @@ -102,13 +107,13 @@ Distributions.mean(p::ESSPrior) = p.μ # Evaluate log-likelihood of proposals. We need this struct because # EllipticalSliceSampling.jl expects a callable struct / a function as its # likelihood. -struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} - ldf::DynamicPPL.LogDensityFunction{M,V} +struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} + ldf::L # Force usage of `getloglikelihood` in inner constructor function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) - return new{typeof(model),typeof(varinfo)}(ldf) + return new{typeof(ldf)}(ldf) end end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 81281389e..265f7dace 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -177,13 +177,12 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuit the tilde assume if `vn` is present in `context`. - # TODO(mhauru) Fix accumulation here. In this branch anything that gets - # accumulated just gets discarded with `_`. - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add @@ -210,13 +209,27 @@ function DynamicPPL.tilde_assume( vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) return if is_target_varname(context, vn) + # This branch means that that `sampler` is supposed to handle + # this variable. We can thus use its default behaviour, with + # the 'local' sampler-specific VarInfo. DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index e19f02343..18733f6a8 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -162,7 +162,9 @@ function find_initial_params( # Resample and try again. # NOTE: varinfo has to be linked to make sure this samples in unconstrained space varinfo = last( - DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform()) + DynamicPPL.evaluate_and_sample!!( + rng, model, varinfo, DynamicPPL.SampleFromUniform() + ), ) end