33
33
function AdvancedPS. advance! (
34
34
trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} , isref:: Bool = false
35
35
)
36
- # Make sure we load/reset the rng in the new replaying mechanism
37
- trace = Accessors. @set trace. model. f. varinfo = DynamicPPL. increment_num_produce!! (
38
- trace. model. f. varinfo
36
+ # We want to increment num produce for the VarInfo stored in the trace. The trace is
37
+ # mutable, so we create a new model with the incremented VarInfo and set it in the trace
38
+ model = trace. model
39
+ model = Accessors. @set model. f. varinfo = DynamicPPL. increment_num_produce!! (
40
+ model. f. varinfo
39
41
)
42
+ trace. model = model
43
+ # Make sure we load/reset the rng in the new replaying mechanism
40
44
isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
41
45
score = consume (trace. model. ctask)
42
46
if score === nothing
@@ -55,10 +59,6 @@ function AdvancedPS.reset_model(trace::TracedModel)
55
59
return Accessors. @set trace. varinfo = DynamicPPL. reset_num_produce!! (trace. varinfo)
56
60
end
57
61
58
- function AdvancedPS. reset_logprob! (trace:: TracedModel )
59
- return Accessors. @set trace. model. varinfo = DynamicPPL. resetlogp!! (trace. model. varinfo)
60
- end
61
-
62
62
function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
63
63
return Libtask. TapedTask (
64
64
taped_globals, model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs...
@@ -390,78 +390,124 @@ function DynamicPPL.use_threadsafe_eval(
390
390
return false
391
391
end
392
392
393
- function trace_local_varinfo_maybe (varinfo)
394
- try
395
- trace = Libtask. get_taped_globals (Any). other
396
- return (trace === nothing ? varinfo : trace. model. f. varinfo):: AbstractVarInfo
393
+ """
394
+ get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
395
+
396
+ Get the `Trace` local varinfo if one exists.
397
+
398
+ If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the
399
+ task, otherwise return `vi`.
400
+ """
401
+ function get_trace_local_varinfo_maybe (varinfo:: AbstractVarInfo )
402
+ trace = try
403
+ Libtask. get_taped_globals (Any). other
397
404
catch e
398
- # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
399
- if e == KeyError (:task_variable )
400
- return varinfo
401
- else
402
- rethrow (e)
403
- end
405
+ e == KeyError (:task_variable ) ? nothing : rethrow (e)
404
406
end
407
+ return (trace === nothing ? varinfo : trace. model. f. varinfo):: AbstractVarInfo
405
408
end
406
409
407
- function trace_local_rng_maybe (rng:: Random.AbstractRNG )
408
- try
409
- return Libtask. get_taped_globals (Any). rng
410
+ """
411
+ get_trace_local_varinfo_maybe(rng::Random.AbstractRNG)
412
+
413
+ Get the `Trace` local rng if one exists.
414
+
415
+ If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
416
+ task, otherwise return `vi`.
417
+ """
418
+ function get_trace_local_rng_maybe (rng:: Random.AbstractRNG )
419
+ return try
420
+ Libtask. get_taped_globals (Any). rng
410
421
catch e
411
- # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
412
- if e == KeyError (:task_variable )
413
- return rng
414
- else
415
- rethrow (e)
416
- end
422
+ e == KeyError (:task_variable ) ? rng : rethrow (e)
417
423
end
418
424
end
419
425
420
- # TODO (DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be
421
- # thoroughly fixed.
426
+ """
427
+ set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
428
+
429
+ Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
430
+
431
+ If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
432
+ task. Otherwise do nothing.
433
+ """
434
+ function set_trace_local_varinfo_maybe (vi:: AbstractVarInfo )
435
+ # TODO (mhauru) This should be done in a try-catch block, as in the commented out code.
436
+ # However, Libtask currently can't handle this block.
437
+ trace = # try
438
+ Libtask. get_taped_globals (Any). other
439
+ # catch e
440
+ # e == KeyError(:task_variable) ? nothing : rethrow(e)
441
+ # end
442
+ if trace != = nothing
443
+ model = trace. model
444
+ model = Accessors. @set model. f. varinfo = vi
445
+ trace. model = model
446
+ end
447
+ return nothing
448
+ end
449
+
422
450
function DynamicPPL. assume (
423
- rng, :: Sampler{<:Union{PG,SMC}} , dist:: Distribution , vn:: VarName , _vi :: AbstractVarInfo
451
+ rng, :: Sampler{<:Union{PG,SMC}} , dist:: Distribution , vn:: VarName , vi :: AbstractVarInfo
424
452
)
425
- vi = trace_local_varinfo_maybe (_vi)
426
- trng = trace_local_rng_maybe (rng)
453
+ arg_vi_id = objectid (vi)
454
+ vi = get_trace_local_varinfo_maybe (vi)
455
+ using_local_vi = objectid (vi) == arg_vi_id
456
+
457
+ trng = get_trace_local_rng_maybe (rng)
427
458
428
459
if ~ haskey (vi, vn)
429
460
r = rand (trng, dist)
430
- push!! (vi, vn, r, dist)
461
+ vi = push!! (vi, vn, r, dist)
431
462
elseif DynamicPPL. is_flagged (vi, vn, " del" )
432
463
DynamicPPL. unset_flag! (vi, vn, " del" ) # Reference particle parent
433
464
r = rand (trng, dist)
434
465
vi[vn] = DynamicPPL. tovec (r)
466
+ # TODO (mhauru):
467
+ # The below is the only line that differs from assume called on SampleFromPrior.
468
+ # Could we just call assume on SampleFromPrior and then `setorder!!` after that?
435
469
vi = DynamicPPL. setorder!! (vi, vn, DynamicPPL. get_num_produce (vi))
436
470
else
437
471
r = vi[vn]
438
472
end
439
- # TODO : call accumulate_assume?!
473
+
474
+ # TODO (mhauru) This get/set business is awful.
475
+ old_logp = DynamicPPL. getlogprior (vi)
476
+ vi = DynamicPPL. accumulate_assume!! (vi, r, 0 , vn, dist)
477
+ vi = DynamicPPL. setlogprior!! (vi, old_logp)
478
+
479
+ # TODO (mhauru) Rather than this if-block, we should use try-catch within
480
+ # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
481
+ # hence this.
482
+ if ! using_local_vi
483
+ set_trace_local_varinfo_maybe (vi)
484
+ end
440
485
return r, vi
441
486
end
442
487
443
- # TODO (mhauru) Fix this.
444
- # function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
445
- # # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
446
- # return logpdf(dist, value), trace_local_varinfo_maybe(vi)
447
- # end
448
-
449
- function DynamicPPL. acclogp!! (
450
- context:: DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}} ,
451
- varinfo:: AbstractVarInfo ,
452
- logp,
488
+ function DynamicPPL. tilde_observe!! (
489
+ ctx:: DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}} , right, left, vn, vi
453
490
)
454
- varinfo_trace = trace_local_varinfo_maybe (varinfo)
455
- return DynamicPPL. acclogp!! (DynamicPPL. childcontext (context), varinfo_trace, logp)
456
- end
491
+ arg_vi_id = objectid (vi)
492
+ vi = get_trace_local_varinfo_maybe (vi)
493
+ using_local_vi = objectid (vi) == arg_vi_id
494
+
495
+ # TODO (mhauru) This get/set business is awful.
496
+ old_logp = DynamicPPL. getloglikelihood (vi)
497
+ left, vi = DynamicPPL. tilde_observe!! (ctx. context, right, left, vn, vi)
498
+ new_loglikelihood = DynamicPPL. getloglikelihood (vi) - old_logp
499
+ vi = DynamicPPL. setloglikelihood!! (vi, old_logp)
500
+
501
+ # TODO (mhauru) Rather than this if-block, we should use try-catch within
502
+ # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
503
+ # hence this.
504
+ if ! using_local_vi
505
+ set_trace_local_varinfo_maybe (vi)
506
+ end
457
507
458
- # TODO (mhauru) Fix this.
459
- # function DynamicPPL.acclogp_observe!!(
460
- # context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
461
- # )
462
- # Libtask.produce(logp)
463
- # return trace_local_varinfo_maybe(varinfo)
464
- # end
508
+ Libtask. produce (new_loglikelihood)
509
+ return left, vi
510
+ end
465
511
466
512
# Convenient constructor
467
513
function AdvancedPS. Trace (
483
529
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
484
530
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
485
531
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
486
- # Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
487
532
Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
488
533
Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
489
534
function Libtask. might_produce (
0 commit comments