Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# 0.40.3

SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce.
For example, if the model is

```julia
@model foo(x; y) = a ~ Normal(x, y)
```

then before samping from this with SMC or PG, you will have to run

```julia
using Libtask;
Libtask.@might_produce(foo);
```

# 0.40.2

`sample(model, NUTS(), N; verbose=false)` now suppresses the 'initial step size' message.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.40.2"
version = "0.40.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -67,7 +67,7 @@ DynamicHMC = "3.4"
DynamicPPL = "0.37"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
Libtask = "0.9.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
Expand Down
76 changes: 44 additions & 32 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ function unset_all_del!(vi::AbstractVarInfo)
return nothing
end

struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
AdvancedPS.AbstractGenericModel
struct TracedModel{
S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple
} <: AdvancedPS.AbstractGenericModel
model::M
sampler::S
varinfo::V
evaluator::E
fargs::T
kwargs::NT
end

function TracedModel(
Expand All @@ -53,13 +55,8 @@ function TracedModel(
spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context)
spl_model = DynamicPPL.contextualize(model, spl_context)
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo)
if kwargs !== nothing && !isempty(kwargs)
error(
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
)
end
evaluator = (spl_model.f, args...)
return TracedModel(spl_model, sampler, varinfo, evaluator)
fargs = (spl_model.f, args...)
return TracedModel(spl_model, sampler, varinfo, fargs, kwargs)
end

function AdvancedPS.advance!(
Expand Down Expand Up @@ -91,9 +88,9 @@ function AdvancedPS.reset_model(trace::TracedModel)
return trace
end

function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
function Libtask.TapedTask(taped_globals, model::TracedModel)
return Libtask.TapedTask(
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs...
)
end

Expand Down Expand Up @@ -183,6 +180,31 @@ function AbstractMCMC.sample(
end
end

function check_model_kwargs(model::DynamicPPL.Model)
if !isempty(model.defaults)
# If there are keyword arguments, we need to check that the user has
# accounted for this by overloading `might_produce`.
might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f)))
if !might_produce
io = IOBuffer()
ctx = IOContext(io, :color => true)
print(
ctx,
"Models with keyword arguments need special treatment to be used" *
" with particle methods. Please run:\n\n",
)
printstyled(
ctx,
" using Libtask; Libtask.@might_produce($(model.f))";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be open to re-exporting this, or something similar, from Turing. However I think it's not hugely important.

bold=true,
color=:blue,
)
print(ctx, "\n\nbefore sampling from this model with particle methods.\n")
error(String(take!(io)))
end
end
end

function DynamicPPL.initialstep(
rng::AbstractRNG,
model::AbstractModel,
Expand All @@ -191,6 +213,7 @@ function DynamicPPL.initialstep(
nparticles::Int,
kwargs...,
)
check_model_kwargs(model)
# Reset the VarInfo.
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
set_all_del!(vi)
Expand Down Expand Up @@ -316,6 +339,7 @@ function DynamicPPL.initialstep(
vi::AbstractVarInfo;
kwargs...,
)
check_model_kwargs(model)
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
# Reset the VarInfo before new sweep
set_all_del!(vi)
Expand Down Expand Up @@ -552,7 +576,7 @@ end
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
# call stack.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
Libtask.@might_produce(DynamicPPL.accloglikelihood!!)
function Libtask.might_produce(
::Type{
<:Tuple{
Expand All @@ -564,25 +588,13 @@ function Libtask.might_produce(
)
return true
end
function Libtask.might_produce(
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
)
return true
end
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
Libtask.@might_produce(DynamicPPL.accumulate_observe!!)
Libtask.@might_produce(DynamicPPL.tilde_observe!!)
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
function Libtask.might_produce(
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
)
return true
end
function Libtask.might_produce(
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
)
return true
end
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
Libtask.@might_produce(DynamicPPL.tilde_assume)
Libtask.@might_produce(DynamicPPL.evaluate!!)
Libtask.@might_produce(DynamicPPL.evaluate_threadsafe!!)
Libtask.@might_produce(DynamicPPL.evaluate_threadunsafe!!)
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
Loading