Skip to content

Commit 97774f9

Browse files
mhaurupenelopeysm
authored andcommitted
Fixes to ProduceLogLikelihoodAccumulator
1 parent 213a55b commit 97774f9

File tree

1 file changed

+41
-76
lines changed

1 file changed

+41
-76
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 41 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -510,82 +510,6 @@ end
510510
511511
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.
512512
513-
# Fields
514-
$(TYPEDFIELDS)
515-
"""
516-
struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
517-
"the scalar log likelihood value"
518-
logp::T
519-
end
520-
521-
# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
522-
# can be used in a given VarInfo.
523-
DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
524-
DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp
525-
526-
function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val)
527-
# The below line is the only difference from `LogLikelihoodAccumulator`.
528-
Libtask.produce(val)
529-
return ProduceLogLikelihoodAccumulator(acc1.logp + val)
530-
end
531-
532-
function DynamicPPL.accumulate_assume!!(
533-
acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right
534-
)
535-
return acc
536-
end
537-
function DynamicPPL.accumulate_observe!!(
538-
acc::ProduceLogLikelihoodAccumulator, right, left, vn
539-
)
540-
return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left))
541-
end
542-
543-
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
544-
# of these won't be needed, because of inlining and the fact that `might_produce` is only
545-
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
546-
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
547-
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
548-
# call stack.
549-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
550-
function Libtask.might_produce(
551-
::Type{
552-
<:Tuple{
553-
typeof(Base.:+),
554-
ProduceLogLikelihoodAccumulator,
555-
DynamicPPL.LogLikelihoodAccumulator,
556-
},
557-
},
558-
)
559-
return true
560-
end
561-
function Libtask.might_produce(
562-
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
563-
)
564-
return true
565-
end
566-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
567-
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
568-
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
569-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
570-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true
571-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
572-
function Libtask.might_produce(
573-
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
574-
)
575-
return true
576-
end
577-
function Libtask.might_produce(
578-
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
579-
)
580-
return true
581-
end
582-
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
583-
584-
"""
585-
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
586-
587-
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
588-
589513
# Fields
590514
$(TYPEDFIELDS)
591515
"""
@@ -674,3 +598,44 @@ function DynamicPPL.convert_eltype(
674598
) where {T}
675599
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
676600
end
601+
602+
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
603+
# of these won't be needed, because of inlining and the fact that `might_produce` is only
604+
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
605+
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
606+
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
607+
# call stack.
608+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
609+
function Libtask.might_produce(
610+
::Type{
611+
<:Tuple{
612+
typeof(Base.:+),
613+
ProduceLogLikelihoodAccumulator,
614+
DynamicPPL.LogLikelihoodAccumulator,
615+
},
616+
},
617+
)
618+
return true
619+
end
620+
function Libtask.might_produce(
621+
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
622+
)
623+
return true
624+
end
625+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
626+
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
627+
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
628+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
629+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true
630+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
631+
function Libtask.might_produce(
632+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
633+
)
634+
return true
635+
end
636+
function Libtask.might_produce(
637+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
638+
)
639+
return true
640+
end
641+
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

0 commit comments

Comments
 (0)