@@ -43,11 +43,7 @@ function AdvancedPS.advance!(
4343 # Make sure we load/reset the rng in the new replaying mechanism
4444 isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
4545 score = consume (trace. model. ctask)
46- if score === nothing
47- return nothing
48- else
49- return score + DynamicPPL. getlogjoint (trace. model. f. varinfo)
50- end
46+ return score
5147end
5248
5349function AdvancedPS. delete_retained! (trace:: TracedModel )
114110
115111function SMCTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , weight)
116112 theta = getparams (model, vi)
117-
118- # This is pretty useless since we reset the log probability continuously in the
119- # particle sweep.
120113 lp = DynamicPPL. getlogjoint (vi)
121-
122114 return SMCTransition (theta, lp, weight)
123115end
124116
@@ -293,11 +285,7 @@ varinfo(state::PGState) = state.vi
293285
294286function PGTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , logevidence)
295287 theta = getparams (model, vi)
296-
297- # This is pretty useless since we reset the log probability continuously in the
298- # particle sweep.
299288 lp = DynamicPPL. getlogjoint (vi)
300-
301289 return PGTransition (theta, lp, logevidence)
302290end
303291
@@ -316,6 +304,7 @@ function DynamicPPL.initialstep(
316304 vi:: AbstractVarInfo ;
317305 kwargs... ,
318306)
307+ vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
319308 # Reset the VarInfo before new sweep
320309 vi = DynamicPPL. reset_num_produce!! (vi)
321310 DynamicPPL. set_retained_vns_del! (vi)
@@ -471,10 +460,7 @@ function DynamicPPL.assume(
471460 r = vi[vn]
472461 end
473462
474- # TODO (mhauru) This get/set business is awful.
475- old_logp = DynamicPPL. getlogprior (vi)
476463 vi = DynamicPPL. accumulate_assume!! (vi, r, 0 , vn, dist)
477- vi = DynamicPPL. setlogprior!! (vi, old_logp)
478464
479465 # TODO (mhauru) Rather than this if-block, we should use try-catch within
480466 # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
@@ -492,20 +478,14 @@ function DynamicPPL.tilde_observe!!(
492478 vi = get_trace_local_varinfo_maybe (vi)
493479 using_local_vi = objectid (vi) == arg_vi_id
494480
495- # TODO (mhauru) This get/set business is awful.
496- old_logp = DynamicPPL. getloglikelihood (vi)
497481 left, vi = DynamicPPL. tilde_observe!! (ctx. context, right, left, vn, vi)
498- new_loglikelihood = DynamicPPL. getloglikelihood (vi) - old_logp
499- vi = DynamicPPL. setloglikelihood!! (vi, old_logp)
500482
501483 # TODO (mhauru) Rather than this if-block, we should use try-catch within
502484 # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
503485 # hence this.
504- if ! using_local_vi
505- set_trace_local_varinfo_maybe (vi)
506- end
507-
508- Libtask. produce (new_loglikelihood)
486+ # if !using_local_vi
487+ # set_trace_local_varinfo_maybe(vi)
488+ # end
509489 return left, vi
510490end
511491
528508# of these won't be needed, because of inlining and the fact that `might_produce` is only
529509# called on `:invoke` expressions rather than `:call`s, but since those are implementation
530510# details of the compiler, we set a bunch of methods as might_produce = true. We start with
531- # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
511+ # adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
512+ # call stack.
513+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}} ) = true
514+ function Libtask. might_produce (
515+ :: Type {
516+ <: Tuple {
517+ typeof (Base.:+ ),
518+ ProduceLogLikelihoodAccumulator,
519+ DynamicPPL. LogLikelihoodAccumulator,
520+ },
521+ },
522+ )
523+ return true
524+ end
525+ function Libtask. might_produce (
526+ :: Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
527+ )
528+ return true
529+ end
532530Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
533531Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
534532function Libtask. might_produce (
@@ -542,3 +540,97 @@ function Libtask.might_produce(
542540 return true
543541end
544542Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
543+
544+ """
545+ ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
546+
547+ Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
548+
549+ # Fields
550+ $(TYPEDFIELDS)
551+ """
552+ struct ProduceLogLikelihoodAccumulator{T<: Real } <: DynamicPPL.AbstractAccumulator
553+ " the scalar log likelihood value"
554+ logp:: T
555+ end
556+
557+ """
558+ ProduceLogLikelihoodAccumulator{T}()
559+
560+ Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero.
561+ """
562+ ProduceLogLikelihoodAccumulator {T} () where {T<: Real } =
563+ ProduceLogLikelihoodAccumulator (zero (T))
564+ function ProduceLogLikelihoodAccumulator ()
565+ return ProduceLogLikelihoodAccumulator {DynamicPPL.LogProbType} ()
566+ end
567+
568+ Base. copy (acc:: ProduceLogLikelihoodAccumulator ) = acc
569+
570+ function Base. show (io:: IO , acc:: ProduceLogLikelihoodAccumulator )
571+ return print (io, " ProduceLogLikelihoodAccumulator($(repr (acc. logp)) )" )
572+ end
573+
574+ function Base.:(== )(
575+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
576+ )
577+ return acc1. logp == acc2. logp
578+ end
579+
580+ function Base. isequal (
581+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
582+ )
583+ return isequal (acc1. logp, acc2. logp)
584+ end
585+
586+ function Base. hash (acc:: ProduceLogLikelihoodAccumulator , h:: UInt )
587+ return hash ((ProduceLogLikelihoodAccumulator, acc. logp), h)
588+ end
589+
590+ # Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
591+ # can be used in a given VarInfo.
592+ DynamicPPL. accumulator_name (:: Type{<:ProduceLogLikelihoodAccumulator} ) = :LogLikelihood
593+
594+ function DynamicPPL. split (:: ProduceLogLikelihoodAccumulator{T} ) where {T}
595+ return ProduceLogLikelihoodAccumulator (zero (T))
596+ end
597+
598+ function DynamicPPL. combine (
599+ acc:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
600+ )
601+ return ProduceLogLikelihoodAccumulator (acc. logp + acc2. logp)
602+ end
603+
604+ function Base.:+ (
605+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: DynamicPPL.LogLikelihoodAccumulator
606+ )
607+ Libtask. produce (acc2. logp)
608+ return ProduceLogLikelihoodAccumulator (acc1. logp + acc2. logp)
609+ end
610+
611+ function Base. zero (acc:: ProduceLogLikelihoodAccumulator )
612+ return ProduceLogLikelihoodAccumulator (zero (acc. logp))
613+ end
614+
615+ function DynamicPPL. accumulate_assume!! (
616+ acc:: ProduceLogLikelihoodAccumulator , val, logjac, vn, right
617+ )
618+ return acc
619+ end
620+ function DynamicPPL. accumulate_observe!! (
621+ acc:: ProduceLogLikelihoodAccumulator , right, left, vn
622+ )
623+ return acc +
624+ DynamicPPL. LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
625+ end
626+
627+ function Base. convert (
628+ :: Type{ProduceLogLikelihoodAccumulator{T}} , acc:: ProduceLogLikelihoodAccumulator
629+ ) where {T}
630+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
631+ end
632+ function DynamicPPL. convert_eltype (
633+ :: Type{T} , acc:: ProduceLogLikelihoodAccumulator
634+ ) where {T}
635+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
636+ end
0 commit comments