Skip to content

Commit 9c66ef2

Browse files
committed
[skip ci] Enable keyword arguments for particle methods
1 parent 5a3f7aa commit 9c66ef2

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ function unset_all_del!(vi::AbstractVarInfo)
3636
return nothing
3737
end
3838

39-
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
40-
AdvancedPS.AbstractGenericModel
39+
struct TracedModel{
40+
S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple
41+
} <: AdvancedPS.AbstractGenericModel
4142
model::M
4243
sampler::S
4344
varinfo::V
44-
evaluator::E
45+
fargs::T
46+
kwargs::NT
4547
end
4648

4749
function TracedModel(
@@ -53,13 +55,11 @@ function TracedModel(
5355
spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context)
5456
spl_model = DynamicPPL.contextualize(model, spl_context)
5557
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo)
56-
if kwargs !== nothing && !isempty(kwargs)
57-
error(
58-
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
59-
)
58+
if !isempty(kwargs)
59+
@warn "The use of particle methods for models with keyword arguments requires special care. Please see <documentation link> for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for."
6060
end
61-
evaluator = (spl_model.f, args...)
62-
return TracedModel(spl_model, sampler, varinfo, evaluator)
61+
fargs = (spl_model.f, args...)
62+
return TracedModel(spl_model, sampler, varinfo, fargs, kwargs)
6363
end
6464

6565
function AdvancedPS.advance!(
@@ -91,9 +91,9 @@ function AdvancedPS.reset_model(trace::TracedModel)
9191
return trace
9292
end
9393

94-
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
94+
function Libtask.TapedTask(taped_globals, model::TracedModel)
9595
return Libtask.TapedTask(
96-
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
96+
taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs...
9797
)
9898
end
9999

@@ -586,3 +586,11 @@ function Libtask.might_produce(
586586
return true
587587
end
588588
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
589+
# This method deals with models that have keyword arguments, although it is alone not
590+
# sufficient to make Libtask fully work with keyword arguments. In here, the second argument
591+
# to Core.kwcall here is `model.f`.
592+
function Libtask.might_produce(
593+
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,<:Any,<:DynamicPPL.Model,Vararg}}
594+
)
595+
return true
596+
end

0 commit comments

Comments
 (0)