Skip to content

Commit 8cb2e19

Browse files
committed
WIP PMCMC work
1 parent b0df6a6 commit 8cb2e19

File tree

1 file changed

+118
-26
lines changed

1 file changed

+118
-26
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ function AdvancedPS.advance!(
4343
# Make sure we load/reset the rng in the new replaying mechanism
4444
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
4545
score = consume(trace.model.ctask)
46-
if score === nothing
47-
return nothing
48-
else
49-
return score + DynamicPPL.getlogjoint(trace.model.f.varinfo)
50-
end
46+
return score
5147
end
5248

5349
function AdvancedPS.delete_retained!(trace::TracedModel)
@@ -114,11 +110,7 @@ end
114110

115111
function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
116112
theta = getparams(model, vi)
117-
118-
# This is pretty useless since we reset the log probability continuously in the
119-
# particle sweep.
120113
lp = DynamicPPL.getlogjoint(vi)
121-
122114
return SMCTransition(theta, lp, weight)
123115
end
124116

@@ -293,11 +285,7 @@ varinfo(state::PGState) = state.vi
293285

294286
function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
295287
theta = getparams(model, vi)
296-
297-
# This is pretty useless since we reset the log probability continuously in the
298-
# particle sweep.
299288
lp = DynamicPPL.getlogjoint(vi)
300-
301289
return PGTransition(theta, lp, logevidence)
302290
end
303291

@@ -316,6 +304,7 @@ function DynamicPPL.initialstep(
316304
vi::AbstractVarInfo;
317305
kwargs...,
318306
)
307+
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
319308
# Reset the VarInfo before new sweep
320309
vi = DynamicPPL.reset_num_produce!!(vi)
321310
DynamicPPL.set_retained_vns_del!(vi)
@@ -471,10 +460,7 @@ function DynamicPPL.assume(
471460
r = vi[vn]
472461
end
473462

474-
# TODO(mhauru) This get/set business is awful.
475-
old_logp = DynamicPPL.getlogprior(vi)
476463
vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist)
477-
vi = DynamicPPL.setlogprior!!(vi, old_logp)
478464

479465
# TODO(mhauru) Rather than this if-block, we should use try-catch within
480466
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
@@ -492,20 +478,14 @@ function DynamicPPL.tilde_observe!!(
492478
vi = get_trace_local_varinfo_maybe(vi)
493479
using_local_vi = objectid(vi) == arg_vi_id
494480

495-
# TODO(mhauru) This get/set business is awful.
496-
old_logp = DynamicPPL.getloglikelihood(vi)
497481
left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi)
498-
new_loglikelihood = DynamicPPL.getloglikelihood(vi) - old_logp
499-
vi = DynamicPPL.setloglikelihood!!(vi, old_logp)
500482

501483
# TODO(mhauru) Rather than this if-block, we should use try-catch within
502484
# `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
503485
# hence this.
504-
if !using_local_vi
505-
set_trace_local_varinfo_maybe(vi)
506-
end
507-
508-
Libtask.produce(new_loglikelihood)
486+
# if !using_local_vi
487+
# set_trace_local_varinfo_maybe(vi)
488+
# end
509489
return left, vi
510490
end
511491

@@ -528,7 +508,25 @@ end
528508
# of these won't be needed, because of inlining and the fact that `might_produce` is only
529509
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
530510
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
531-
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
511+
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
512+
# call stack.
513+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
514+
function Libtask.might_produce(
515+
::Type{
516+
<:Tuple{
517+
typeof(Base.:+),
518+
ProduceLogLikelihoodAccumulator,
519+
DynamicPPL.LogLikelihoodAccumulator,
520+
},
521+
},
522+
)
523+
return true
524+
end
525+
function Libtask.might_produce(
526+
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
527+
)
528+
return true
529+
end
532530
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
533531
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
534532
function Libtask.might_produce(
@@ -542,3 +540,97 @@ function Libtask.might_produce(
542540
return true
543541
end
544542
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
543+
544+
"""
545+
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
546+
547+
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
548+
549+
# Fields
550+
$(TYPEDFIELDS)
551+
"""
552+
struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.AbstractAccumulator
553+
"the scalar log likelihood value"
554+
logp::T
555+
end
556+
557+
"""
558+
ProduceLogLikelihoodAccumulator{T}()
559+
560+
Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero.
561+
"""
562+
ProduceLogLikelihoodAccumulator{T}() where {T<:Real} =
563+
ProduceLogLikelihoodAccumulator(zero(T))
564+
function ProduceLogLikelihoodAccumulator()
565+
return ProduceLogLikelihoodAccumulator{DynamicPPL.LogProbType}()
566+
end
567+
568+
Base.copy(acc::ProduceLogLikelihoodAccumulator) = acc
569+
570+
function Base.show(io::IO, acc::ProduceLogLikelihoodAccumulator)
571+
return print(io, "ProduceLogLikelihoodAccumulator($(repr(acc.logp)))")
572+
end
573+
574+
function Base.:(==)(
575+
acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
576+
)
577+
return acc1.logp == acc2.logp
578+
end
579+
580+
function Base.isequal(
581+
acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
582+
)
583+
return isequal(acc1.logp, acc2.logp)
584+
end
585+
586+
function Base.hash(acc::ProduceLogLikelihoodAccumulator, h::UInt)
587+
return hash((ProduceLogLikelihoodAccumulator, acc.logp), h)
588+
end
589+
590+
# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
591+
# can be used in a given VarInfo.
592+
DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
593+
594+
function DynamicPPL.split(::ProduceLogLikelihoodAccumulator{T}) where {T}
595+
return ProduceLogLikelihoodAccumulator(zero(T))
596+
end
597+
598+
function DynamicPPL.combine(
599+
acc::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
600+
)
601+
return ProduceLogLikelihoodAccumulator(acc.logp + acc2.logp)
602+
end
603+
604+
function Base.:+(
605+
acc1::ProduceLogLikelihoodAccumulator, acc2::DynamicPPL.LogLikelihoodAccumulator
606+
)
607+
Libtask.produce(acc2.logp)
608+
return ProduceLogLikelihoodAccumulator(acc1.logp + acc2.logp)
609+
end
610+
611+
function Base.zero(acc::ProduceLogLikelihoodAccumulator)
612+
return ProduceLogLikelihoodAccumulator(zero(acc.logp))
613+
end
614+
615+
function DynamicPPL.accumulate_assume!!(
616+
acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right
617+
)
618+
return acc
619+
end
620+
function DynamicPPL.accumulate_observe!!(
621+
acc::ProduceLogLikelihoodAccumulator, right, left, vn
622+
)
623+
return acc +
624+
DynamicPPL.LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
625+
end
626+
627+
function Base.convert(
628+
::Type{ProduceLogLikelihoodAccumulator{T}}, acc::ProduceLogLikelihoodAccumulator
629+
) where {T}
630+
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
631+
end
632+
function DynamicPPL.convert_eltype(
633+
::Type{T}, acc::ProduceLogLikelihoodAccumulator
634+
) where {T}
635+
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
636+
end

0 commit comments

Comments
 (0)