Skip to content

Commit 213a55b

Browse files
mhaurupenelopeysm
authored andcommitted
WIP PMCMC work
1 parent c062867 commit 213a55b

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,97 @@ function Libtask.might_produce(
580580
return true
581581
end
582582
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
583+
584+
"""
585+
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
586+
587+
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on every increase.
588+
589+
# Fields
590+
$(TYPEDFIELDS)
591+
"""
592+
struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.AbstractAccumulator
593+
"the scalar log likelihood value"
594+
logp::T
595+
end
596+
597+
"""
598+
ProduceLogLikelihoodAccumulator{T}()
599+
600+
Create a new `ProduceLogLikelihoodAccumulator` accumulator with the log likelihood of zero.
601+
"""
602+
ProduceLogLikelihoodAccumulator{T}() where {T<:Real} =
603+
ProduceLogLikelihoodAccumulator(zero(T))
604+
function ProduceLogLikelihoodAccumulator()
605+
return ProduceLogLikelihoodAccumulator{DynamicPPL.LogProbType}()
606+
end
607+
608+
Base.copy(acc::ProduceLogLikelihoodAccumulator) = acc
609+
610+
function Base.show(io::IO, acc::ProduceLogLikelihoodAccumulator)
611+
return print(io, "ProduceLogLikelihoodAccumulator($(repr(acc.logp)))")
612+
end
613+
614+
function Base.:(==)(
615+
acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
616+
)
617+
return acc1.logp == acc2.logp
618+
end
619+
620+
function Base.isequal(
621+
acc1::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
622+
)
623+
return isequal(acc1.logp, acc2.logp)
624+
end
625+
626+
function Base.hash(acc::ProduceLogLikelihoodAccumulator, h::UInt)
627+
return hash((ProduceLogLikelihoodAccumulator, acc.logp), h)
628+
end
629+
630+
# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two
631+
# can be used in a given VarInfo.
632+
DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
633+
634+
function DynamicPPL.split(::ProduceLogLikelihoodAccumulator{T}) where {T}
635+
return ProduceLogLikelihoodAccumulator(zero(T))
636+
end
637+
638+
function DynamicPPL.combine(
639+
acc::ProduceLogLikelihoodAccumulator, acc2::ProduceLogLikelihoodAccumulator
640+
)
641+
return ProduceLogLikelihoodAccumulator(acc.logp + acc2.logp)
642+
end
643+
644+
function Base.:+(
645+
acc1::ProduceLogLikelihoodAccumulator, acc2::DynamicPPL.LogLikelihoodAccumulator
646+
)
647+
Libtask.produce(acc2.logp)
648+
return ProduceLogLikelihoodAccumulator(acc1.logp + acc2.logp)
649+
end
650+
651+
function Base.zero(acc::ProduceLogLikelihoodAccumulator)
652+
return ProduceLogLikelihoodAccumulator(zero(acc.logp))
653+
end
654+
655+
function DynamicPPL.accumulate_assume!!(
656+
acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right
657+
)
658+
return acc
659+
end
660+
function DynamicPPL.accumulate_observe!!(
661+
acc::ProduceLogLikelihoodAccumulator, right, left, vn
662+
)
663+
return acc +
664+
DynamicPPL.LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
665+
end
666+
667+
function Base.convert(
668+
::Type{ProduceLogLikelihoodAccumulator{T}}, acc::ProduceLogLikelihoodAccumulator
669+
) where {T}
670+
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
671+
end
672+
function DynamicPPL.convert_eltype(
673+
::Type{T}, acc::ProduceLogLikelihoodAccumulator
674+
) where {T}
675+
return ProduceLogLikelihoodAccumulator(convert(T, acc.logp))
676+
end

0 commit comments

Comments
 (0)