diff --git a/HISTORY.md b/HISTORY.md index d7334c60b..6f07b0e66 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# 0.40.3 + +SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce. +For example, if the model is + +```julia +@model foo(x; y) = a ~ Normal(x, y) +``` + +then before samping from this with SMC or PG, you will have to run + +```julia +using Libtask; +Libtask.@might_produce(foo); +``` + # 0.40.2 `sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message. diff --git a/Project.toml b/Project.toml index 422284095..3a6b7a4eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.2" +version = "0.40.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -67,7 +67,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.37" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.3" +Libtask = "0.9.5" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ab2add975..3f14f8e77 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -36,12 +36,14 @@ function unset_all_del!(vi::AbstractVarInfo) return nothing end -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: - AdvancedPS.AbstractGenericModel +struct TracedModel{ + S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple +} <: AdvancedPS.AbstractGenericModel model::M sampler::S varinfo::V - evaluator::E + fargs::T + kwargs::NT end function TracedModel( @@ -53,13 +55,8 @@ function TracedModel( spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) spl_model = DynamicPPL.contextualize(model, spl_context) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) - if kwargs !== nothing && !isempty(kwargs) - error( - "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", - ) - end - evaluator = (spl_model.f, args...) - return TracedModel(spl_model, sampler, varinfo, evaluator) + fargs = (spl_model.f, args...) + return TracedModel(spl_model, sampler, varinfo, fargs, kwargs) end function AdvancedPS.advance!( @@ -91,9 +88,9 @@ function AdvancedPS.reset_model(trace::TracedModel) return trace end -function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) +function Libtask.TapedTask(taped_globals, model::TracedModel) return Libtask.TapedTask( - taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... + taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs... ) end @@ -183,6 +180,31 @@ function AbstractMCMC.sample( end end +function check_model_kwargs(model::DynamicPPL.Model) + if !isempty(model.defaults) + # If there are keyword arguments, we need to check that the user has + # accounted for this by overloading `might_produce`. + might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f))) + if !might_produce + io = IOBuffer() + ctx = IOContext(io, :color => true) + print( + ctx, + "Models with keyword arguments need special treatment to be used" * + " with particle methods. Please run:\n\n", + ) + printstyled( + ctx, + " using Libtask; Libtask.@might_produce($(model.f))"; + bold=true, + color=:blue, + ) + print(ctx, "\n\nbefore sampling from this model with particle methods.\n") + error(String(take!(io))) + end + end +end + function DynamicPPL.initialstep( rng::AbstractRNG, model::AbstractModel, @@ -191,6 +213,7 @@ function DynamicPPL.initialstep( nparticles::Int, kwargs..., ) + check_model_kwargs(model) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) set_all_del!(vi) @@ -316,6 +339,7 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) + check_model_kwargs(model) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep set_all_del!(vi) @@ -552,7 +576,7 @@ end # details of the compiler, we set a bunch of methods as might_produce = true. We start with # adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the # call stack. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.accloglikelihood!!) function Libtask.might_produce( ::Type{ <:Tuple{ @@ -564,25 +588,13 @@ function Libtask.might_produce( ) return true end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} -) - return true -end -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.accumulate_observe!!) +Libtask.@might_produce(DynamicPPL.tilde_observe!!) # Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? # That's the only thing that makes tilde_assume calls result in tilde_observe calls. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} -) - return true -end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} -) - return true -end +Libtask.@might_produce(DynamicPPL.tilde_assume!!) +Libtask.@might_produce(DynamicPPL.tilde_assume) +Libtask.@might_produce(DynamicPPL.evaluate!!) +Libtask.@might_produce(DynamicPPL.evaluate_threadsafe!!) +Libtask.@might_produce(DynamicPPL.evaluate_threadunsafe!!) Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true