From b0df6a66a0b5ce28462fabc4380e3c36647235cb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Jul 2025 15:12:25 +0100 Subject: [PATCH 1/6] Progress in DPPL 0.37 compat for particle MCMC --- src/mcmc/particle_mcmc.jl | 153 ++++++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 54 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 6cf8fc315..975754d27 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -33,10 +33,14 @@ end function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) - # Make sure we load/reset the rng in the new replaying mechanism - trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!( - trace.model.f.varinfo + # We want to increment num produce for the VarInfo stored in the trace. The trace is + # mutable, so we create a new model with the incremented VarInfo and set it in the trace + model = trace.model + model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!( + model.f.varinfo ) + trace.model = model + # Make sure we load/reset the rng in the new replaying mechanism isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) if score === nothing @@ -55,10 +59,6 @@ function AdvancedPS.reset_model(trace::TracedModel) return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) end -function AdvancedPS.reset_logprob!(trace::TracedModel) - return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo) -end - function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... 
@@ -390,78 +390,124 @@ function DynamicPPL.use_threadsafe_eval( return false end -function trace_local_varinfo_maybe(varinfo) - try - trace = Libtask.get_taped_globals(Any).other - return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo +""" + get_trace_local_varinfo_maybe(vi::AbstractVarInfo) + +Get the `Trace` local varinfo if one exists. + +If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the +task, otherwise return `vi`. +""" +function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo) + trace = try + Libtask.get_taped_globals(Any).other catch e - # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:task_variable) - return varinfo - else - rethrow(e) - end + e == KeyError(:task_variable) ? nothing : rethrow(e) end + return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo end -function trace_local_rng_maybe(rng::Random.AbstractRNG) - try - return Libtask.get_taped_globals(Any).rng +""" + get_trace_local_rng_maybe(rng::Random.AbstractRNG) + +Get the `Trace` local rng if one exists. + +If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the +task, otherwise return `rng`. +""" +function get_trace_local_rng_maybe(rng::Random.AbstractRNG) + return try + Libtask.get_taped_globals(Any).rng catch e - # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:task_variable) - return rng - else - rethrow(e) - end + e == KeyError(:task_variable) ? rng : rethrow(e) end end -# TODO(DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be -# thoroughly fixed. +""" + set_trace_local_varinfo_maybe(vi::AbstractVarInfo) + +Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`. + +If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the +task. Otherwise do nothing. 
+""" +function set_trace_local_varinfo_maybe(vi::AbstractVarInfo) + # TODO(mhauru) This should be done in a try-catch block, as in the commented out code. + # However, Libtask currently can't handle this block. + trace = #try + Libtask.get_taped_globals(Any).other + # catch e + # e == KeyError(:task_variable) ? nothing : rethrow(e) + # end + if trace !== nothing + model = trace.model + model = Accessors.@set model.f.varinfo = vi + trace.model = model + end + return nothing +end + function DynamicPPL.assume( - rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo + rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) - vi = trace_local_varinfo_maybe(_vi) - trng = trace_local_rng_maybe(rng) + arg_vi_id = objectid(vi) + vi = get_trace_local_varinfo_maybe(vi) + using_local_vi = objectid(vi) == arg_vi_id + + trng = get_trace_local_rng_maybe(rng) if ~haskey(vi, vn) r = rand(trng, dist) - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent r = rand(trng, dist) vi[vn] = DynamicPPL.tovec(r) + # TODO(mhauru): + # The below is the only line that differs from assume called on SampleFromPrior. + # Could we just call assume on SampleFromPrior and then `setorder!!` after that? vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) else r = vi[vn] end - # TODO: call accumulate_assume?! + + # TODO(mhauru) This get/set business is awful. + old_logp = DynamicPPL.getlogprior(vi) + vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist) + vi = DynamicPPL.setlogprior!!(vi, old_logp) + + # TODO(mhauru) Rather than this if-block, we should use try-catch within + # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, + # hence this. + if !using_local_vi + set_trace_local_varinfo_maybe(vi) + end return r, vi end -# TODO(mhauru) Fix this. 
-# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) -# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. -# return logpdf(dist, value), trace_local_varinfo_maybe(vi) -# end - -function DynamicPPL.acclogp!!( - context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, - varinfo::AbstractVarInfo, - logp, +function DynamicPPL.tilde_observe!!( + ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi ) - varinfo_trace = trace_local_varinfo_maybe(varinfo) - return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) -end + arg_vi_id = objectid(vi) + vi = get_trace_local_varinfo_maybe(vi) + using_local_vi = objectid(vi) == arg_vi_id + + # TODO(mhauru) This get/set business is awful. + old_logp = DynamicPPL.getloglikelihood(vi) + left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi) + new_loglikelihood = DynamicPPL.getloglikelihood(vi) - old_logp + vi = DynamicPPL.setloglikelihood!!(vi, old_logp) + + # TODO(mhauru) Rather than this if-block, we should use try-catch within + # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, + # hence this. + if !using_local_vi + set_trace_local_varinfo_maybe(vi) + end -# TODO(mhauru) Fix this. -# function DynamicPPL.acclogp_observe!!( -# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp -# ) -# Libtask.produce(logp) -# return trace_local_varinfo_maybe(varinfo) -# end + Libtask.produce(new_loglikelihood) + return left, vi +end # Convenient constructor function AdvancedPS.Trace( @@ -483,7 +529,6 @@ end # called on `:invoke` expressions rather than `:call`s, but since those are implementation # details of the compiler, we set a bunch of methods as might_produce = true. We start with # `acclogp_observe!!` which is what calls `produce` and go up the call stack. 
-# Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( From 8cb2e19ae42db0d657f6808f77d044a67b3b7221 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 23 Jul 2025 11:40:31 +0100 Subject: [PATCH 2/6] WIP PMCMC work --- src/mcmc/particle_mcmc.jl | 144 +++++++++++++++++++++++++++++++------- 1 file changed, 118 insertions(+), 26 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 975754d27..c1cc90aa5 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -43,11 +43,7 @@ function AdvancedPS.advance!( # Make sure we load/reset the rng in the new replaying mechanism isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) - if score === nothing - return nothing - else - return score + DynamicPPL.getlogjoint(trace.model.f.varinfo) - end + return score end function AdvancedPS.delete_retained!(trace::TracedModel) @@ -114,11 +110,7 @@ end function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. lp = DynamicPPL.getlogjoint(vi) - return SMCTransition(theta, lp, weight) end @@ -293,11 +285,7 @@ varinfo(state::PGState) = state.vi function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. 
lp = DynamicPPL.getlogjoint(vi) - return PGTransition(theta, lp, logevidence) end @@ -316,6 +304,7 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @@ -471,10 +460,7 @@ function DynamicPPL.assume( r = vi[vn] end - # TODO(mhauru) This get/set business is awful. - old_logp = DynamicPPL.getlogprior(vi) vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist) - vi = DynamicPPL.setlogprior!!(vi, old_logp) # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, @@ -492,20 +478,14 @@ function DynamicPPL.tilde_observe!!( vi = get_trace_local_varinfo_maybe(vi) using_local_vi = objectid(vi) == arg_vi_id - # TODO(mhauru) This get/set business is awful. - old_logp = DynamicPPL.getloglikelihood(vi) left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi) - new_loglikelihood = DynamicPPL.getloglikelihood(vi) - old_logp - vi = DynamicPPL.setloglikelihood!!(vi, old_logp) # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, # hence this. - if !using_local_vi - set_trace_local_varinfo_maybe(vi) - end - - Libtask.produce(new_loglikelihood) + # if !using_local_vi + # set_trace_local_varinfo_maybe(vi) + # end return left, vi end @@ -528,7 +508,25 @@ end # of these won't be needed, because of inlining and the fact that `might_produce` is only # called on `:invoke` expressions rather than `:call`s, but since those are implementation # details of the compiler, we set a bunch of methods as might_produce = true. We start with -# `acclogp_observe!!` which is what calls `produce` and go up the call stack. 
+# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the +# call stack. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{ + <:Tuple{ + typeof(Base.:+), + ProduceLogLikelihoodAccumulator, + DynamicPPL.LogLikelihoodAccumulator, + }, + }, +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} +) + return true +end Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( @@ -542,3 +540,97 @@ function Libtask.might_produce( return true end Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true + +""" + ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + +Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase. + +# Fields +$(TYPEDFIELDS) +""" +struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.AbstractAccumulator + "the scalar log likelihood value" + logp::T +end + +""" + ProduceLogLikelihoodAccumulator{T}() + +Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero. 
+""" +ProduceLogLikelihoodAccumulator{T}() where {T<:Real} = + ProduceLogLikelihoodAccumulator(zero(T)) +function ProduceLogLikelihoodAccumulator() + return ProduceLogLikelihoodAccumulator{DynamicPPL.LogProbType}() +end + +Base.copy(acc::ProduceLogLikelihoodAccumulator) = acc + +function Base.show(io::IO, acc::ProduceLogLikelihoodAccumulator) + return print(io, "ProduceLogLikelihoodAccumulator($(repr(acc.logp)))") +end + +function Base.:(==)( + acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator +) + return acc1.logp == acc2.logp +end + +function Base.isequal( + acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator +) + return isequal(acc1.logp, acc2.logp) +end + +function Base.hash(acc::ProduceLogLikelihoodAccumulator, h::UInt) + return hash((ProduceLogLikelihoodAccumulator, acc.logp), h) +end + +# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two +# can be used in a given VarInfo. +DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood + +function DynamicPPL.split(::ProduceLogLikelihoodAccumulator{T}) where {T} + return ProduceLogLikelihoodAccumulator(zero(T)) +end + +function DynamicPPL.combine( + acc::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator +) + return ProduceLogLikelihoodAccumulator(acc.logp + acc2.logp) +end + +function Base.:+( + acc1::ProduceLogLikelihoodAccumulator, acc2::DynamicPPL.LogLikelihoodAccumulator +) + Libtask.produce(acc2.logp) + return ProduceLogLikelihoodAccumulator(acc1.logp + acc2.logp) +end + +function Base.zero(acc::ProduceLogLikelihoodAccumulator) + return ProduceLogLikelihoodAccumulator(zero(acc.logp)) +end + +function DynamicPPL.accumulate_assume!!( + acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right +) + return acc +end +function DynamicPPL.accumulate_observe!!( + acc::ProduceLogLikelihoodAccumulator, right, left, vn +) + return acc + + 
DynamicPPL.LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) +end + +function Base.convert( + ::Type{ProduceLogLikelihoodAccumulator{T}}, acc::ProduceLogLikelihoodAccumulator +) where {T} + return ProduceLogLikelihoodAccumulator(convert(T, acc.logp)) +end +function DynamicPPL.convert_eltype( + ::Type{T}, acc::ProduceLogLikelihoodAccumulator +) where {T} + return ProduceLogLikelihoodAccumulator(convert(T, acc.logp)) +end From 31f73310022bd15329d6f3810abffc95ee3a295d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Jul 2025 15:13:18 +0100 Subject: [PATCH 3/6] Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628) * Obviously this single commit will make Gibbs work * Fixes for ESS * Fix HMC call * improve some comments --- src/mcmc/ess.jl | 13 +++++++++---- src/mcmc/gibbs.jl | 35 ++++++++++++++++++++++++----------- src/mcmc/hmc.jl | 4 +++- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index c49b52d3a..feb737a30 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -54,7 +54,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setloglikelihood!!(vi, state.loglikelihood) + vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason # why we had to use the 'del' flag before this was because # SampleFromPrior() wouldn't overwrite existing variables. + # The main problem I'm rather unsure about is ESS-within-Gibbs. The + # current implementation I think makes sure to only resample the variables + # that 'belong' to the current ESS sampler. InitContext on the other hand + # would resample all variables in the model (??) Need to think about this + # carefully. 
vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") @@ -102,13 +107,13 @@ Distributions.mean(p::ESSPrior) = p.μ # Evaluate log-likelihood of proposals. We need this struct because # EllipticalSliceSampling.jl expects a callable struct / a function as its # likelihood. -struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} - ldf::DynamicPPL.LogDensityFunction{M,V} +struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} + ldf::L # Force usage of `getloglikelihood` in inner constructor function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) - return new{typeof(model),typeof(varinfo)}(ldf) + return new{typeof(ldf)}(ldf) end end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 81281389e..265f7dace 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -177,13 +177,12 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuit the tilde assume if `vn` is present in `context`. - # TODO(mhauru) Fix accumulation here. In this branch anything that gets - # accumulated just gets discarded with `_`. - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. 
We need to add @@ -210,13 +209,27 @@ function DynamicPPL.tilde_assume( vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) return if is_target_varname(context, vn) + # This branch means that `sampler` is supposed to handle + # this variable. We can thus use its default behaviour, with + # the 'local' sampler-specific VarInfo. DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it + # will trigger resampling. We may need to do a special kind of observe + # that does not trigger resampling. + global_vi = get_global_varinfo(context) + val = global_vi[vn] + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else + # If the varname has not been conditioned on, nor is it a target variable, its + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index e19f02343..18733f6a8 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -162,7 +162,9 @@ function find_initial_params( # Resample and try again. 
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space varinfo = last( - DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform()) + DynamicPPL.evaluate_and_sample!!( + rng, model, varinfo, DynamicPPL.SampleFromUniform() + ), ) end From bdcd72f1e4f60c5542578f8bf4fc57e8307eef60 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 23 Jul 2025 17:09:34 +0100 Subject: [PATCH 4/6] Fixes to ProduceLogLikelihoodAccumulator --- src/mcmc/particle_mcmc.jl | 87 +++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 41 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c1cc90aa5..bc294c6f1 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -175,6 +175,7 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) vi = DynamicPPL.resetlogp!!(vi) @@ -483,9 +484,9 @@ function DynamicPPL.tilde_observe!!( # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, # hence this. - # if !using_local_vi - # set_trace_local_varinfo_maybe(vi) - # end + if !using_local_vi + set_trace_local_varinfo_maybe(vi) + end return left, vi end @@ -504,47 +505,10 @@ function AdvancedPS.Trace( return newtrace end -# We need to tell Libtask which calls may have `produce` calls within them. In practice most -# of these won't be needed, because of inlining and the fact that `might_produce` is only -# called on `:invoke` expressions rather than `:call`s, but since those are implementation -# details of the compiler, we set a bunch of methods as might_produce = true. We start with -# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the -# call stack. 
-Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true -function Libtask.might_produce( - ::Type{ - <:Tuple{ - typeof(Base.:+), - ProduceLogLikelihoodAccumulator, - DynamicPPL.LogLikelihoodAccumulator, - }, - }, -) - return true -end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} -) - return true -end -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} -) - return true -end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} -) - return true -end -Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true - """ ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator -Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase. +Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value. # Fields $(TYPEDFIELDS) @@ -634,3 +598,44 @@ function DynamicPPL.convert_eltype( ) where {T} return ProduceLogLikelihoodAccumulator(convert(T, acc.logp)) end + +# We need to tell Libtask which calls may have `produce` calls within them. In practice most +# of these won't be needed, because of inlining and the fact that `might_produce` is only +# called on `:invoke` expressions rather than `:call`s, but since those are implementation +# details of the compiler, we set a bunch of methods as might_produce = true. We start with +# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the +# call stack. 
+Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{ + <:Tuple{ + typeof(Base.:+), + ProduceLogLikelihoodAccumulator, + DynamicPPL.LogLikelihoodAccumulator, + }, + }, +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} +) + return true +end +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +# Could the next two have tighter type bounds on the arguments, namely a GibbsContext? +# That's the only thing that makes tilde_assume calls result in tilde_observe calls. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} +) + return true +end +Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true From 0dde8bc829ea62a4a9ce8ff506d9635398a27bd0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Jul 2025 15:47:34 +0100 Subject: [PATCH 5/6] Use LogProbAccumulator for ProduceLogLikelihoodAccumulator --- src/mcmc/particle_mcmc.jl | 73 ++++----------------------------------- 1 file changed, 7 insertions(+), 66 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index bc294c6f1..17feb18c1 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -513,67 +513,20 @@ Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change o # Fields $(TYPEDFIELDS) """ -struct 
ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T} "the scalar log likelihood value" logp::T end -""" - ProduceLogLikelihoodAccumulator{T}() - -Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero. -""" -ProduceLogLikelihoodAccumulator{T}() where {T<:Real} = - ProduceLogLikelihoodAccumulator(zero(T)) -function ProduceLogLikelihoodAccumulator() - return ProduceLogLikelihoodAccumulator{DynamicPPL.LogProbType}() -end - -Base.copy(acc::ProduceLogLikelihoodAccumulator) = acc - -function Base.show(io::IO, acc::ProduceLogLikelihoodAccumulator) - return print(io, "ProduceLogLikelihoodAccumulator($(repr(acc.logp)))") -end - -function Base.:(==)( - acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator -) - return acc1.logp == acc2.logp -end - -function Base.isequal( - acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator -) - return isequal(acc1.logp, acc2.logp) -end - -function Base.hash(acc::ProduceLogLikelihoodAccumulator, h::UInt) - return hash((ProduceLogLikelihoodAccumulator, acc.logp), h) -end - # Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two # can be used in a given VarInfo. 
DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood +DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp -function DynamicPPL.split(::ProduceLogLikelihoodAccumulator{T}) where {T} - return ProduceLogLikelihoodAccumulator(zero(T)) -end - -function DynamicPPL.combine( - acc::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator -) - return ProduceLogLikelihoodAccumulator(acc.logp + acc2.logp) -end - -function Base.:+( - acc1::ProduceLogLikelihoodAccumulator, acc2::DynamicPPL.LogLikelihoodAccumulator -) - Libtask.produce(acc2.logp) - return ProduceLogLikelihoodAccumulator(acc1.logp + acc2.logp) -end - -function Base.zero(acc::ProduceLogLikelihoodAccumulator) - return ProduceLogLikelihoodAccumulator(zero(acc.logp)) +function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val) + # The below line is the only difference from `LogLikelihoodAccumulator`. + Libtask.produce(val) + return ProduceLogLikelihoodAccumulator(acc1.logp + val) end function DynamicPPL.accumulate_assume!!( @@ -584,19 +537,7 @@ end function DynamicPPL.accumulate_observe!!( acc::ProduceLogLikelihoodAccumulator, right, left, vn ) - return acc + - DynamicPPL.LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) -end - -function Base.convert( - ::Type{ProduceLogLikelihoodAccumulator{T}}, acc::ProduceLogLikelihoodAccumulator -) where {T} - return ProduceLogLikelihoodAccumulator(convert(T, acc.logp)) -end -function DynamicPPL.convert_eltype( - ::Type{T}, acc::ProduceLogLikelihoodAccumulator -) where {T} - return ProduceLogLikelihoodAccumulator(convert(T, acc.logp)) + return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left)) end # We need to tell Libtask which calls may have `produce` calls within them. 
In practice most From 6d6fac8ddf7e8b232fdfd7f635c5d75d43e3c510 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 31 Jul 2025 12:05:50 +0100 Subject: [PATCH 6/6] use get_conditioned_gibbs --- src/mcmc/gibbs.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 265f7dace..58db29789 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -177,11 +177,14 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it - # will trigger resampling. We may need to do a special kind of observe - # that does not trigger resampling. - global_vi = get_global_varinfo(context) - val = global_vi[vn] + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its @@ -219,11 +222,9 @@ function DynamicPPL.tilde_assume( # conditioned on, so we can just treat it as an observation. # The only catch is that the value that we need is to be obtained from # the global VarInfo (since the local VarInfo has no knowledge of it). - # TODO(DPPL0.37/penelopeysm): Unsure if this is bad for SMC as it - # will trigger resampling. 
We may need to do a special kind of observe - # that does not trigger resampling. - global_vi = get_global_varinfo(context) - val = global_vi[vn] + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its