@@ -580,3 +580,97 @@ function Libtask.might_produce(
580
580
return true
581
581
end
582
582
Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
583
+
584
+ """
585
+ ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
586
+
587
+ Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
588
+
589
+ # Fields
590
+ $(TYPEDFIELDS)
591
+ """
592
+ struct ProduceLogLikelihoodAccumulator{T<: Real } <: DynamicPPL.AbstractAccumulator
593
+ " the scalar log likelihood value"
594
+ logp:: T
595
+ end
596
+
597
+ """
598
+ ProduceLogLikelihoodAccumulator{T}()
599
+
600
+ Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero.
601
+ """
602
+ ProduceLogLikelihoodAccumulator {T} () where {T<: Real } =
603
+ ProduceLogLikelihoodAccumulator (zero (T))
604
+ function ProduceLogLikelihoodAccumulator ()
605
+ return ProduceLogLikelihoodAccumulator {DynamicPPL.LogProbType} ()
606
+ end
607
+
608
+ Base. copy (acc:: ProduceLogLikelihoodAccumulator ) = acc
609
+
610
+ function Base. show (io:: IO , acc:: ProduceLogLikelihoodAccumulator )
611
+ return print (io, " ProduceLogLikelihoodAccumulator($(repr (acc. logp)) )" )
612
+ end
613
+
614
+ function Base.:(== )(
615
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
616
+ )
617
+ return acc1. logp == acc2. logp
618
+ end
619
+
620
+ function Base. isequal (
621
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
622
+ )
623
+ return isequal (acc1. logp, acc2. logp)
624
+ end
625
+
626
+ function Base. hash (acc:: ProduceLogLikelihoodAccumulator , h:: UInt )
627
+ return hash ((ProduceLogLikelihoodAccumulator, acc. logp), h)
628
+ end
629
+
630
+ # Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
631
+ # can be used in a given VarInfo.
632
+ DynamicPPL. accumulator_name (:: Type{<:ProduceLogLikelihoodAccumulator} ) = :LogLikelihood
633
+
634
+ function DynamicPPL. split (:: ProduceLogLikelihoodAccumulator{T} ) where {T}
635
+ return ProduceLogLikelihoodAccumulator (zero (T))
636
+ end
637
+
638
+ function DynamicPPL. combine (
639
+ acc:: ProduceLogLikelihoodAccumulator , acc2:: ProduceLogLikelihoodAccumulator
640
+ )
641
+ return ProduceLogLikelihoodAccumulator (acc. logp + acc2. logp)
642
+ end
643
+
644
+ function Base.:+ (
645
+ acc1:: ProduceLogLikelihoodAccumulator , acc2:: DynamicPPL.LogLikelihoodAccumulator
646
+ )
647
+ Libtask. produce (acc2. logp)
648
+ return ProduceLogLikelihoodAccumulator (acc1. logp + acc2. logp)
649
+ end
650
+
651
+ function Base. zero (acc:: ProduceLogLikelihoodAccumulator )
652
+ return ProduceLogLikelihoodAccumulator (zero (acc. logp))
653
+ end
654
+
655
+ function DynamicPPL. accumulate_assume!! (
656
+ acc:: ProduceLogLikelihoodAccumulator , val, logjac, vn, right
657
+ )
658
+ return acc
659
+ end
660
+ function DynamicPPL. accumulate_observe!! (
661
+ acc:: ProduceLogLikelihoodAccumulator , right, left, vn
662
+ )
663
+ return acc +
664
+ DynamicPPL. LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
665
+ end
666
+
667
+ function Base. convert (
668
+ :: Type{ProduceLogLikelihoodAccumulator{T}} , acc:: ProduceLogLikelihoodAccumulator
669
+ ) where {T}
670
+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
671
+ end
672
+ function DynamicPPL. convert_eltype (
673
+ :: Type{T} , acc:: ProduceLogLikelihoodAccumulator
674
+ ) where {T}
675
+ return ProduceLogLikelihoodAccumulator (convert (T, acc. logp))
676
+ end
0 commit comments