Skip to content

Commit 8b4e7f2

Browse files
committed
[skip ci] Enable keyword arguments for particle methods
1 parent 5a3f7aa commit 8b4e7f2

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ function unset_all_del!(vi::AbstractVarInfo)
3636
return nothing
3737
end
3838

39-
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
40-
AdvancedPS.AbstractGenericModel
39+
struct TracedModel{
40+
S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple
41+
} <: AdvancedPS.AbstractGenericModel
4142
model::M
4243
sampler::S
4344
varinfo::V
44-
evaluator::E
45+
fargs::T
46+
kwargs::NT
4547
end
4648

4749
function TracedModel(
@@ -53,13 +55,8 @@ function TracedModel(
5355
spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context)
5456
spl_model = DynamicPPL.contextualize(model, spl_context)
5557
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo)
56-
if kwargs !== nothing && !isempty(kwargs)
57-
error(
58-
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
59-
)
60-
end
61-
evaluator = (spl_model.f, args...)
62-
return TracedModel(spl_model, sampler, varinfo, evaluator)
58+
fargs = (spl_model.f, args...)
59+
return TracedModel(spl_model, sampler, varinfo, fargs, kwargs)
6360
end
6461

6562
function AdvancedPS.advance!(
@@ -91,9 +88,9 @@ function AdvancedPS.reset_model(trace::TracedModel)
9188
return trace
9289
end
9390

94-
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
91+
function Libtask.TapedTask(taped_globals, model::TracedModel)
9592
return Libtask.TapedTask(
96-
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
93+
taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs...
9794
)
9895
end
9996

@@ -191,6 +188,9 @@ function DynamicPPL.initialstep(
191188
nparticles::Int,
192189
kwargs...,
193190
)
191+
if !isempty(model.defaults)
192+
@warn "The use of particle methods for models with keyword arguments requires special care. Please see <documentation link> for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for."
193+
end
194194
# Reset the VarInfo.
195195
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
196196
set_all_del!(vi)
@@ -316,6 +316,9 @@ function DynamicPPL.initialstep(
316316
vi::AbstractVarInfo;
317317
kwargs...,
318318
)
319+
if !isempty(model.defaults)
320+
@warn "The use of particle methods for models with keyword arguments requires special care. Please see <documentation link> for more details and be sure to check the results you obtain to make sure that observations are being properly accounted for."
321+
end
319322
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
320323
# Reset the VarInfo before new sweep
321324
set_all_del!(vi)
@@ -586,3 +589,11 @@ function Libtask.might_produce(
586589
return true
587590
end
588591
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
592+
# This method deals with models that have keyword arguments, although it is alone not
593+
# sufficient to make Libtask fully work with keyword arguments. In here, the second argument
594+
# to Core.kwcall here is `model.f`.
595+
function Libtask.might_produce(
596+
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,<:Any,<:DynamicPPL.Model,Vararg}}
597+
)
598+
return true
599+
end

0 commit comments

Comments
 (0)