Skip to content

Commit b0df6a6

Browse files
committed
Progress in DPPL 0.37 compat for particle MCMC
1 parent 11a2a31 commit b0df6a6

File tree

1 file changed

+99
-54
lines changed

1 file changed

+99
-54
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 99 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ end
3333
function AdvancedPS.advance!(
3434
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
3535
)
36-
# Make sure we load/reset the rng in the new replaying mechanism
37-
trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!(
38-
trace.model.f.varinfo
36+
# We want to increment num produce for the VarInfo stored in the trace. The trace is
37+
# mutable, so we create a new model with the incremented VarInfo and set it in the trace
38+
model = trace.model
39+
model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!(
40+
model.f.varinfo
3941
)
42+
trace.model = model
43+
# Make sure we load/reset the rng in the new replaying mechanism
4044
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
4145
score = consume(trace.model.ctask)
4246
if score === nothing
@@ -55,10 +59,6 @@ function AdvancedPS.reset_model(trace::TracedModel)
5559
return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo)
5660
end
5761

58-
function AdvancedPS.reset_logprob!(trace::TracedModel)
59-
return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo)
60-
end
61-
6262
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
6363
return Libtask.TapedTask(
6464
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
@@ -390,78 +390,124 @@ function DynamicPPL.use_threadsafe_eval(
390390
return false
391391
end
392392

393-
function trace_local_varinfo_maybe(varinfo)
394-
try
395-
trace = Libtask.get_taped_globals(Any).other
396-
return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
393+
"""
394+
get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
395+
396+
Get the `Trace` local varinfo if one exists.
397+
398+
If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the
399+
task, otherwise return `vi`.
400+
"""
401+
function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo)
402+
trace = try
403+
Libtask.get_taped_globals(Any).other
397404
catch e
398-
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
399-
if e == KeyError(:task_variable)
400-
return varinfo
401-
else
402-
rethrow(e)
403-
end
405+
e == KeyError(:task_variable) ? nothing : rethrow(e)
404406
end
407+
return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
405408
end
406409

407-
function trace_local_rng_maybe(rng::Random.AbstractRNG)
408-
try
409-
return Libtask.get_taped_globals(Any).rng
410+
"""
411+
get_trace_local_varinfo_maybe(rng::Random.AbstractRNG)
412+
413+
Get the `Trace` local rng if one exists.
414+
415+
If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
416+
task, otherwise return `vi`.
417+
"""
418+
function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
419+
return try
420+
Libtask.get_taped_globals(Any).rng
410421
catch e
411-
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
412-
if e == KeyError(:task_variable)
413-
return rng
414-
else
415-
rethrow(e)
416-
end
422+
e == KeyError(:task_variable) ? rng : rethrow(e)
417423
end
418424
end
419425

420-
# TODO(DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be
421-
# thoroughly fixed.
426+
"""
427+
set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
428+
429+
Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
430+
431+
If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
432+
task. Otherwise do nothing.
433+
"""
434+
function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
435+
# TODO(mhauru) This should be done in a try-catch block, as in the commented out code.
436+
# However, Libtask currently can't handle this block.
437+
trace = #try
438+
Libtask.get_taped_globals(Any).other
439+
# catch e
440+
# e == KeyError(:task_variable) ? nothing : rethrow(e)
441+
# end
442+
if trace !== nothing
443+
model = trace.model
444+
model = Accessors.@set model.f.varinfo = vi
445+
trace.model = model
446+
end
447+
return nothing
448+
end
449+
422450
function DynamicPPL.assume(
423-
rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo
451+
rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo
424452
)
425-
vi = trace_local_varinfo_maybe(_vi)
426-
trng = trace_local_rng_maybe(rng)
453+
arg_vi_id = objectid(vi)
454+
vi = get_trace_local_varinfo_maybe(vi)
455+
using_local_vi = objectid(vi) == arg_vi_id
456+
457+
trng = get_trace_local_rng_maybe(rng)
427458

428459
if ~haskey(vi, vn)
429460
r = rand(trng, dist)
430-
push!!(vi, vn, r, dist)
461+
vi = push!!(vi, vn, r, dist)
431462
elseif DynamicPPL.is_flagged(vi, vn, "del")
432463
DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
433464
r = rand(trng, dist)
434465
vi[vn] = DynamicPPL.tovec(r)
466+
# TODO(mhauru):
467+
# The below is the only line that differs from assume called on SampleFromPrior.
468+
# Could we just call assume on SampleFromPrior and then `setorder!!` after that?
435469
vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi))
436470
else
437471
r = vi[vn]
438472
end
439-
# TODO: call accumulate_assume?!
473+
474+
# TODO(mhauru) This get/set business is awful.
475+
old_logp = DynamicPPL.getlogprior(vi)
476+
vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist)
477+
vi = DynamicPPL.setlogprior!!(vi, old_logp)
478+
479+
# TODO(mhauru) Rather than this if-block, we should use try-catch within
480+
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
481+
# hence this.
482+
if !using_local_vi
483+
set_trace_local_varinfo_maybe(vi)
484+
end
440485
return r, vi
441486
end
442487

443-
# TODO(mhauru) Fix this.
444-
# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
445-
# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
446-
# return logpdf(dist, value), trace_local_varinfo_maybe(vi)
447-
# end
448-
449-
function DynamicPPL.acclogp!!(
450-
context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}},
451-
varinfo::AbstractVarInfo,
452-
logp,
488+
function DynamicPPL.tilde_observe!!(
489+
ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi
453490
)
454-
varinfo_trace = trace_local_varinfo_maybe(varinfo)
455-
return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
456-
end
491+
arg_vi_id = objectid(vi)
492+
vi = get_trace_local_varinfo_maybe(vi)
493+
using_local_vi = objectid(vi) == arg_vi_id
494+
495+
# TODO(mhauru) This get/set business is awful.
496+
old_logp = DynamicPPL.getloglikelihood(vi)
497+
left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi)
498+
new_loglikelihood = DynamicPPL.getloglikelihood(vi) - old_logp
499+
vi = DynamicPPL.setloglikelihood!!(vi, old_logp)
500+
501+
# TODO(mhauru) Rather than this if-block, we should use try-catch within
502+
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
503+
# hence this.
504+
if !using_local_vi
505+
set_trace_local_varinfo_maybe(vi)
506+
end
457507

458-
# TODO(mhauru) Fix this.
459-
# function DynamicPPL.acclogp_observe!!(
460-
# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
461-
# )
462-
# Libtask.produce(logp)
463-
# return trace_local_varinfo_maybe(varinfo)
464-
# end
508+
Libtask.produce(new_loglikelihood)
509+
return left, vi
510+
end
465511

466512
# Convenient constructor
467513
function AdvancedPS.Trace(
@@ -483,7 +529,6 @@ end
483529
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
484530
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
485531
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
486-
# Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
487532
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
488533
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
489534
function Libtask.might_produce(

0 commit comments

Comments
 (0)