Skip to content

Commit bdcd72f

Browse files
committed
Fixes to ProduceLogLikelihoodAccumulator
1 parent da2b510 commit bdcd72f

File tree

1 file changed

+46
-41
lines changed

1 file changed

+46
-41
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ function DynamicPPL.initialstep(
175175
kwargs...,
176176
)
177177
# Reset the VarInfo.
178+
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
178179
vi = DynamicPPL.reset_num_produce!!(vi)
179180
DynamicPPL.set_retained_vns_del!(vi)
180181
vi = DynamicPPL.resetlogp!!(vi)
@@ -483,9 +484,9 @@ function DynamicPPL.tilde_observe!!(
483484
# TODO(mhauru) Rather than this if-block, we should use try-catch within
484485
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
485486
# hence this.
486-
# if !using_local_vi
487-
# set_trace_local_varinfo_maybe(vi)
488-
# end
487+
if !using_local_vi
488+
set_trace_local_varinfo_maybe(vi)
489+
end
489490
return left, vi
490491
end
491492

@@ -504,47 +505,10 @@ function AdvancedPS.Trace(
504505
return newtrace
505506
end
506507

507-
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
508-
# of these won't be needed, because of inlining and the fact that `might_produce` is only
509-
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
510-
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
511-
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
512-
# call stack.
513-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
514-
function Libtask.might_produce(
515-
::Type{
516-
<:Tuple{
517-
typeof(Base.:+),
518-
ProduceLogLikelihoodAccumulator,
519-
DynamicPPL.LogLikelihoodAccumulator,
520-
},
521-
},
522-
)
523-
return true
524-
end
525-
function Libtask.might_produce(
526-
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
527-
)
528-
return true
529-
end
530-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
531-
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
532-
function Libtask.might_produce(
533-
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
534-
)
535-
return true
536-
end
537-
function Libtask.might_produce(
538-
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
539-
)
540-
return true
541-
end
542-
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
543-
544508
"""
545509
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
546510
547-
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
511+
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.
548512
549513
# Fields
550514
$(TYPEDFIELDS)
@@ -634,3 +598,44 @@ function DynamicPPL.convert_eltype(
634598
) where {T}
635599
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
636600
end
601+
602+
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
603+
# of these won't be needed, because of inlining and the fact that `might_produce` is only
604+
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
605+
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
606+
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
607+
# call stack.
608+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
609+
function Libtask.might_produce(
610+
::Type{
611+
<:Tuple{
612+
typeof(Base.:+),
613+
ProduceLogLikelihoodAccumulator,
614+
DynamicPPL.LogLikelihoodAccumulator,
615+
},
616+
},
617+
)
618+
return true
619+
end
620+
function Libtask.might_produce(
621+
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
622+
)
623+
return true
624+
end
625+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
626+
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
627+
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
628+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
629+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true
630+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
631+
function Libtask.might_produce(
632+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
633+
)
634+
return true
635+
end
636+
function Libtask.might_produce(
637+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
638+
)
639+
return true
640+
end
641+
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

0 commit comments

Comments
 (0)