From 8b4e7f270b3e7483ec313adf765209662cf3a6ac Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 18 Aug 2025 18:47:32 +0100 Subject: [PATCH 1/4] [skip ci] Enable keyword arguments for particle methods --- src/mcmc/particle_mcmc.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ab2add975..5318aa34d 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -36,12 +36,14 @@ function unset_all_del!(vi::AbstractVarInfo) return nothing end -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: - AdvancedPS.AbstractGenericModel +struct TracedModel{ + S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple +} <: AdvancedPS.AbstractGenericModel model::M sampler::S varinfo::V - evaluator::E + fargs::T + kwargs::NT end function TracedModel( @@ -53,13 +55,8 @@ function TracedModel( spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) spl_model = DynamicPPL.contextualize(model, spl_context) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) - if kwargs !== nothing && !isempty(kwargs) - error( - "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", - ) - end - evaluator = (spl_model.f, args...) - return TracedModel(spl_model, sampler, varinfo, evaluator) + fargs = (spl_model.f, args...) + return TracedModel(spl_model, sampler, varinfo, fargs, kwargs) end function AdvancedPS.advance!( @@ -91,9 +88,9 @@ function AdvancedPS.reset_model(trace::TracedModel) return trace end -function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) +function Libtask.TapedTask(taped_globals, model::TracedModel) return Libtask.TapedTask( - taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... + taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs... ) end @@ -191,6 +188,9 @@ function DynamicPPL.initialstep( nparticles::Int, kwargs..., ) + if !isempty(model.defaults) + @warn "The use of particle methods for models with keyword arguments requires special care. Please see for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for." + end # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) set_all_del!(vi) @@ -316,6 +316,9 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) + if !isempty(model.defaults) + @warn "The use of particle methods for models with keyword arguments requires special care. Please see for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for." + end vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep set_all_del!(vi) @@ -586,3 +589,11 @@ function Libtask.might_produce( return true end Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true +# This method deals with models that have keyword arguments, although it is alone not +# sufficient to make Libtask fully work with keyword arguments. In here, the second argument +# to Core.kwcall here is `model.f`. +function Libtask.might_produce( + ::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,<:Any,<:DynamicPPL.Model,Vararg}} +) + return true +end From dfd23c229b63bcb8ad214f87f3e721a9301716ba Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 19 Aug 2025 15:18:30 +0100 Subject: [PATCH 2/4] Hint at @might_produce --- src/mcmc/particle_mcmc.jl | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 5318aa34d..be10da9ba 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -180,6 +180,31 @@ function AbstractMCMC.sample( end end +function check_model_kwargs(model::DynamicPPL.Model) + if !isempty(model.defaults) + # If there are keyword arguments, we need to check that the user has + # accounted for this by overloading `might_produce`. + might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f))) + if !might_produce + io = IOBuffer() + ctx = IOContext(io, :color => true) + print( + ctx, + "Models with keyword arguments need special treatment to be used" * + " with particle methods. Please run:\n\n", + ) + printstyled( + ctx, + " using Libtask; Libtask.@might_produce($(model.f))"; + bold=true, + color=:blue, + ) + print(ctx, "\n\nbefore sampling from this model with particle methods.\n") + error(String(take!(io))) + end + end +end + function DynamicPPL.initialstep( rng::AbstractRNG, model::AbstractModel, @@ -188,9 +213,7 @@ function DynamicPPL.initialstep( nparticles::Int, kwargs..., ) - if !isempty(model.defaults) - @warn "The use of particle methods for models with keyword arguments requires special care. Please see for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for." - end + check_model_kwargs(model) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) set_all_del!(vi) @@ -316,9 +339,7 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - if !isempty(model.defaults) - @warn "The use of particle methods for models with keyword arguments requires special care. Please see for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for." - end + check_model_kwargs(model) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep set_all_del!(vi) From 160b7d3f8718248b16f0deb67982a666d65e32b6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 19 Aug 2025 15:21:06 +0100 Subject: [PATCH 3/4] Add changelog, bump patch --- HISTORY.md | 16 ++++++++++++++++ Project.toml | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d7334c60b..6f07b0e66 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# 0.40.3 + +SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce. +For example, if the model is + +```julia +@model foo(x; y) = a ~ Normal(x, y) +``` + +then before samping from this with SMC or PG, you will have to run + +```julia +using Libtask; +Libtask.@might_produce(foo); +``` + # 0.40.2 `sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message. diff --git a/Project.toml b/Project.toml index 422284095..3a6b7a4eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.2" +version = "0.40.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -67,7 +67,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.37" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.3" +Libtask = "0.9.5" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" From d80a8defc7e9ea50fe84d153a9702cba25651dde Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 19 Aug 2025 15:26:42 +0100 Subject: [PATCH 4/4] Why not use the macro for everything else --- src/mcmc/particle_mcmc.jl | 36 ++++++++---------------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index be10da9ba..3f14f8e77 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -576,7 +576,7 @@ end # details of the compiler, we set a bunch of methods as might_produce = true. We start with # adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the # call stack. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.accloglikelihood!!) function Libtask.might_produce( ::Type{ <:Tuple{ @@ -588,33 +588,13 @@ function Libtask.might_produce( ) return true end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} -) - return true -end -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.accumulate_observe!!) +Libtask.@might_produce(DynamicPPL.tilde_observe!!) # Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? # That's the only thing that makes tilde_assume calls result in tilde_observe calls. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} -) - return true -end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} -) - return true -end +Libtask.@might_produce(DynamicPPL.tilde_assume!!) +Libtask.@might_produce(DynamicPPL.tilde_assume) +Libtask.@might_produce(DynamicPPL.evaluate!!) +Libtask.@might_produce(DynamicPPL.evaluate_threadsafe!!) +Libtask.@might_produce(DynamicPPL.evaluate_threadunsafe!!) Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true -# This method deals with models that have keyword arguments, although it is alone not -# sufficient to make Libtask fully work with keyword arguments. In here, the second argument -# to Core.kwcall here is `model.f`. -function Libtask.might_produce( - ::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,<:Any,<:DynamicPPL.Model,Vararg}} -) - return true -end