|
510 | 510 |
|
511 | 511 | Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.
|
512 | 512 |
|
513 |
| -# Fields |
514 |
| -$(TYPEDFIELDS) |
515 |
| -""" |
516 |
| -struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T} |
517 |
| - "the scalar log likelihood value" |
518 |
| - logp::T |
519 |
| -end |
520 |
| - |
521 |
| -# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two |
522 |
| -# can be used in a given VarInfo. |
523 |
| -DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood |
524 |
| -DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp |
525 |
| - |
526 |
| -function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val) |
527 |
| - # The below line is the only difference from `LogLikelihoodAccumulator`. |
528 |
| - Libtask.produce(val) |
529 |
| - return ProduceLogLikelihoodAccumulator(acc1.logp + val) |
530 |
| -end |
531 |
| - |
532 |
| -function DynamicPPL.accumulate_assume!!( |
533 |
| - acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right |
534 |
| -) |
535 |
| - return acc |
536 |
| -end |
537 |
| -function DynamicPPL.accumulate_observe!!( |
538 |
| - acc::ProduceLogLikelihoodAccumulator, right, left, vn |
539 |
| -) |
540 |
| - return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left)) |
541 |
| -end |
542 |
| - |
543 |
| -# We need to tell Libtask which calls may have `produce` calls within them. In practice most |
544 |
| -# of these won't be needed, because of inlining and the fact that `might_produce` is only |
545 |
| -# called on `:invoke` expressions rather than `:call`s, but since those are implementation |
546 |
| -# details of the compiler, we set a bunch of methods as might_produce = true. We start with |
547 |
| -# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the |
548 |
| -# call stack. |
549 |
| -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true |
550 |
| -function Libtask.might_produce( |
551 |
| - ::Type{ |
552 |
| - <:Tuple{ |
553 |
| - typeof(Base.:+), |
554 |
| - ProduceLogLikelihoodAccumulator, |
555 |
| - DynamicPPL.LogLikelihoodAccumulator, |
556 |
| - }, |
557 |
| - }, |
558 |
| -) |
559 |
| - return true |
560 |
| -end |
561 |
| -function Libtask.might_produce( |
562 |
| - ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} |
563 |
| -) |
564 |
| - return true |
565 |
| -end |
566 |
| -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true |
567 |
| -# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? |
568 |
| -# That's the only thing that makes tilde_assume calls result in tilde_observe calls. |
569 |
| -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true |
570 |
| -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true |
571 |
| -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true |
572 |
| -function Libtask.might_produce( |
573 |
| - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} |
574 |
| -) |
575 |
| - return true |
576 |
| -end |
577 |
| -function Libtask.might_produce( |
578 |
| - ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} |
579 |
| -) |
580 |
| - return true |
581 |
| -end |
582 |
| -Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true |
583 |
| - |
584 |
| -""" |
585 |
| - ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator |
586 |
| -
|
587 |
| -Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase. |
588 |
| -
|
589 | 513 | # Fields
|
590 | 514 | $(TYPEDFIELDS)
|
591 | 515 | """
|
@@ -674,3 +598,44 @@ function DynamicPPL.convert_eltype(
|
674 | 598 | ) where {T}
|
675 | 599 | return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
|
676 | 600 | end
|
| 601 | + |
| 602 | +# We need to tell Libtask which calls may have `produce` calls within them. In practice most |
| 603 | +# of these won't be needed, because of inlining and the fact that `might_produce` is only |
| 604 | +# called on `:invoke` expressions rather than `:call`s, but since those are implementation |
| 605 | +# details of the compiler, we set a bunch of methods as might_produce = true. We start with |
| 606 | +# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the |
| 607 | +# call stack. |
| 608 | +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true |
| 609 | +function Libtask.might_produce( |
| 610 | + ::Type{ |
| 611 | + <:Tuple{ |
| 612 | + typeof(Base.:+), |
| 613 | + ProduceLogLikelihoodAccumulator, |
| 614 | + DynamicPPL.LogLikelihoodAccumulator, |
| 615 | + }, |
| 616 | + }, |
| 617 | +) |
| 618 | + return true |
| 619 | +end |
| 620 | +function Libtask.might_produce( |
| 621 | + ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} |
| 622 | +) |
| 623 | + return true |
| 624 | +end |
| 625 | +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true |
| 626 | +# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? |
| 627 | +# That's the only thing that makes tilde_assume calls result in tilde_observe calls. |
| 628 | +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true |
| 629 | +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true |
| 630 | +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true |
| 631 | +function Libtask.might_produce( |
| 632 | + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} |
| 633 | +) |
| 634 | + return true |
| 635 | +end |
| 636 | +function Libtask.might_produce( |
| 637 | + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} |
| 638 | +) |
| 639 | + return true |
| 640 | +end |
| 641 | +Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true |
0 commit comments