From 876ce17de330bcb7b80824a4a9d27d7ee547768f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Jul 2025 00:26:37 +0100 Subject: [PATCH 1/4] Obviously this single commit will make Gibbs work --- src/mcmc/gibbs.jl | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 81281389e..bd925e10b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -178,12 +178,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - # TODO(mhauru) Fix accumulation here. In this branch anything that gets - # accumulated just gets discarded with `_`. - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add @@ -210,13 +207,22 @@ function DynamicPPL.tilde_assume( vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) return if is_target_varname(context, vn) + # This branch means that that `sampler` is supposed to handle + # this variable. We can thus use its default behaviour, with + # the 'local' sampler-specific VarInfo. DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else + # This is a variable that isn't handled by any sampler. We can just + # sample a new value and stick it inside the global VarInfo. value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, From 729b5d87bfacc81bf65cb186cb8382e186b71e22 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Jul 2025 01:27:24 +0100 Subject: [PATCH 2/4] Fixes for ESS --- src/mcmc/ess.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index c49b52d3a..0a3090448 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -54,7 +54,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setloglikelihood!!(vi, state.loglikelihood) + vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason # why we had to use the 'del' flag before this was because # SampleFromPrior() wouldn't overwrite existing variables. + # The main problem I'm rather unsure about is ESS-within-Gibbs. The + # current implementation I think makes sure to only resample the variables + # that 'belong' to the current ESS sampler. InitContext on the other hand + # would resample all variables in the model (??) Need to think about this + # carefully. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") @@ -102,13 +107,14 @@ Distributions.mean(p::ESSPrior) = p.μ # Evaluate log-likelihood of proposals. We need this struct because # EllipticalSliceSampling.jl expects a callable struct / a function as its # likelihood. -struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} - ldf::DynamicPPL.LogDensityFunction{M,V} +struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} + ldf::L # Force usage of `getloglikelihood` in inner constructor function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) + # new_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(DynamicPPL.SampleFromPrior())) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) - return new{typeof(model),typeof(varinfo)}(ldf) + return new{typeof(ldf)}(ldf) end end From cc07ca31411227aaa0d114d647690faf0373def0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Jul 2025 01:45:25 +0100 Subject: [PATCH 3/4] Fix HMC call --- src/mcmc/hmc.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index e19f02343..18733f6a8 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -162,7 +162,9 @@ function find_initial_params( # Resample and try again. # NOTE: varinfo has to be linked to make sure this samples in unconstrained space varinfo = last( - DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform()) + DynamicPPL.evaluate_and_sample!!( + rng, model, varinfo, DynamicPPL.SampleFromUniform() + ), ) end From c0c298e1b5f743ae7f131b9710a5ccb2f2cb61db Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Jul 2025 02:16:40 +0100 Subject: [PATCH 4/4] improve some comments --- src/mcmc/ess.jl | 1 - src/mcmc/gibbs.jl | 13 ++++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 0a3090448..feb737a30 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -112,7 +112,6 @@ struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} # Force usage of `getloglikelihood` in inner constructor function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) - # new_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(DynamicPPL.SampleFromPrior())) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) return new{typeof(ldf)}(ldf) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index bd925e10b..265f7dace 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -177,7 +177,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuit the tilde assume if `vn` is present in `context`. + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. global_vi = get_global_varinfo(context) val = global_vi[vn] DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) @@ -217,12 +219,17 @@ function DynamicPPL.tilde_assume( # conditioned on, so we can just treat it as an observation. # The only catch is that the value that we need is to be obtained from # the global VarInfo (since the local VarInfo has no knowledge of it). + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. global_vi = get_global_varinfo(context) val = global_vi[vn] DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else - # This is a variable that isn't handled by any sampler. We can just - # sample a new value and stick it inside the global VarInfo. + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context,