@@ -43,11 +43,7 @@ function AdvancedPS.advance!(
43
43
# Make sure we load/reset the rng in the new replaying mechanism
44
44
isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
45
45
score = consume (trace. model. ctask)
46
- if score === nothing
47
- return nothing
48
- else
49
- return score + DynamicPPL. getlogjoint (trace. model. f. varinfo)
50
- end
46
+ return score
51
47
end
52
48
53
49
function AdvancedPS. delete_retained! (trace:: TracedModel )
114
110
115
111
function SMCTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , weight)
116
112
theta = getparams (model, vi)
117
-
118
- # This is pretty useless since we reset the log probability continuously in the
119
- # particle sweep.
120
113
lp = DynamicPPL. getlogjoint (vi)
121
-
122
114
return SMCTransition (theta, lp, weight)
123
115
end
124
116
@@ -293,11 +285,7 @@ varinfo(state::PGState) = state.vi
293
285
294
286
function PGTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , logevidence)
295
287
theta = getparams (model, vi)
296
-
297
- # This is pretty useless since we reset the log probability continuously in the
298
- # particle sweep.
299
288
lp = DynamicPPL. getlogjoint (vi)
300
-
301
289
return PGTransition (theta, lp, logevidence)
302
290
end
303
291
@@ -316,6 +304,7 @@ function DynamicPPL.initialstep(
316
304
vi:: AbstractVarInfo ;
317
305
kwargs... ,
318
306
)
307
+ vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
319
308
# Reset the VarInfo before new sweep
320
309
vi = DynamicPPL. reset_num_produce!! (vi)
321
310
DynamicPPL. set_retained_vns_del! (vi)
@@ -471,10 +460,7 @@ function DynamicPPL.assume(
471
460
r = vi[vn]
472
461
end
473
462
474
- # TODO (mhauru) This get/set business is awful.
475
- old_logp = DynamicPPL. getlogprior (vi)
476
463
vi = DynamicPPL. accumulate_assume!! (vi, r, 0 , vn, dist)
477
- vi = DynamicPPL. setlogprior!! (vi, old_logp)
478
464
479
465
# TODO (mhauru) Rather than this if-block, we should use try-catch within
480
466
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
@@ -492,20 +478,14 @@ function DynamicPPL.tilde_observe!!(
492
478
vi = get_trace_local_varinfo_maybe (vi)
493
479
using_local_vi = objectid (vi) == arg_vi_id
494
480
495
- # TODO (mhauru) This get/set business is awful.
496
- old_logp = DynamicPPL. getloglikelihood (vi)
497
481
left, vi = DynamicPPL. tilde_observe!! (ctx. context, right, left, vn, vi)
498
- new_loglikelihood = DynamicPPL. getloglikelihood (vi) - old_logp
499
- vi = DynamicPPL. setloglikelihood!! (vi, old_logp)
500
482
501
483
# TODO (mhauru) Rather than this if-block, we should use try-catch within
502
484
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
503
485
# hence this.
504
- if ! using_local_vi
505
- set_trace_local_varinfo_maybe (vi)
506
- end
507
-
508
- Libtask. produce (new_loglikelihood)
486
+ # if !using_local_vi
487
+ # set_trace_local_varinfo_maybe(vi)
488
+ # end
509
489
return left, vi
510
490
end
511
491
528
508
# of these won't be needed, because of inlining and the fact that `might_produce` is only
529
509
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
530
510
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
531
- # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
511
+ # adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
512
+ # call stack.
513
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}} ) = true
514
+ function Libtask. might_produce (
515
+ :: Type {
516
+ <: Tuple {
517
+ typeof (Base.:+ ),
518
+ ProduceLogLikelihoodAccumulator,
519
+ DynamicPPL. LogLikelihoodAccumulator,
520
+ },
521
+ },
522
+ )
523
+ return true
524
+ end
525
+ function Libtask. might_produce (
526
+ :: Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
527
+ )
528
+ return true
529
+ end
532
530
Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
533
531
Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
534
532
function Libtask. might_produce (
@@ -542,3 +540,97 @@ function Libtask.might_produce(
542
540
return true
543
541
end
544
542
Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
543
+
544
+ """
545
+ ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
546
+
547
+ Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
548
+
549
+ # Fields
550
+ $(TYPEDFIELDS)
551
+ """
552
+ struct ProduceLogLikelihoodAccumulator{T<: Real } <: DynamicPPL.AbstractAccumulator
553
+ " the scalar log likelihood value"
554
+ logp:: T
555
+ end
556
+
557
+ """
558
+ ProduceLogLikelihoodAccumulator{T}()
559
+
560
+ Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero.
561
+ """
562
+ ProduceLogLikelihoodAccumulator {T} () where {T<: Real } =
563
+ ProduceLogLikelihoodAccumulator (zero (T))
564
+ function ProduceLogLikelihoodAccumulator ()
565
+ return ProduceLogLikelihoodAccumulator {DynamicPPL.LogProbType} ()
566
+ end
567
+
568
+ Base. copy (acc:: ProduceLogLikelihoodAccumulator ) = acc
569
+
570
+ function Base. show (io:: IO , acc:: ProduceLogLikelihoodAccumulator )
571
+ return print (io, " ProduceLogLikelihoodAccumulator($(repr (acc. logp)) )" )
572
+ end
573
+
574
+ function Base.:(== )(
575
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
576
+ )
577
+ return acc1. logp == acc2. logp
578
+ end
579
+
580
+ function Base. isequal (
581
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
582
+ )
583
+ return isequal (acc1. logp, acc2. logp)
584
+ end
585
+
586
+ function Base. hash (acc:: ProduceLogLikelihoodAccumulator , h:: UInt )
587
+ return hash ((ProduceLogLikelihoodAccumulator, acc. logp), h)
588
+ end
589
+
590
+ # Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
591
+ # can be used in a given VarInfo.
592
+ DynamicPPL. accumulator_name (:: Type{<:ProduceLogLikelihoodAccumulator} ) = :LogLikelihood
593
+
594
+ function DynamicPPL. split (:: ProduceLogLikelihoodAccumulator{T} ) where {T}
595
+ return ProduceLogLikelihoodAccumulator (zero (T))
596
+ end
597
+
598
+ function DynamicPPL. combine (
599
+ acc:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
600
+ )
601
+ return ProduceLogLikelihoodAccumulator (acc. logp + acc2. logp)
602
+ end
603
+
604
+ function Base.:+ (
605
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: DynamicPPL.LogLikelihoodAccumulator
606
+ )
607
+ Libtask. produce (acc2. logp)
608
+ return ProduceLogLikelihoodAccumulator (acc1. logp + acc2. logp)
609
+ end
610
+
611
+ function Base. zero (acc:: ProduceLogLikelihoodAccumulator )
612
+ return ProduceLogLikelihoodAccumulator (zero (acc. logp))
613
+ end
614
+
615
+ function DynamicPPL. accumulate_assume!! (
616
+ acc:: ProduceLogLikelihoodAccumulator , val, logjac, vn, right
617
+ )
618
+ return acc
619
+ end
620
+ function DynamicPPL. accumulate_observe!! (
621
+ acc:: ProduceLogLikelihoodAccumulator , right, left, vn
622
+ )
623
+ return acc +
624
+ DynamicPPL. LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
625
+ end
626
+
627
+ function Base. convert (
628
+ :: Type{ProduceLogLikelihoodAccumulator{T}} , acc:: ProduceLogLikelihoodAccumulator
629
+ ) where {T}
630
+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
631
+ end
632
+ function DynamicPPL. convert_eltype (
633
+ :: Type{T} , acc:: ProduceLogLikelihoodAccumulator
634
+ ) where {T}
635
+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
636
+ end
0 commit comments