Skip to content

Commit f983da5

Browse files
committed
Replace + with accumulate for LogProbAccs
1 parent e9bf50b commit f983da5

File tree

3 files changed

+12
-24
lines changed

3 files changed

+12
-24
lines changed

src/abstract_varinfo.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`.
358358
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
359359
"""
360360
function acclogprior!!(vi::AbstractVarInfo, logp)
361-
return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
361+
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior))
362362
end
363363

364364
"""
@@ -369,9 +369,7 @@ Add `logJ` to the value of the log Jacobian in `vi`.
369369
See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
370370
"""
371371
function acclogjac!!(vi::AbstractVarInfo, logJ)
372-
return map_accumulator!!(
373-
acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian)
374-
)
372+
return map_accumulator!!(acc -> acclogp(acc, logJ), vi, Val(:LogJacobian))
375373
end
376374

377375
"""
@@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`.
382380
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
383381
"""
384382
function accloglikelihood!!(vi::AbstractVarInfo, logp)
385-
return map_accumulator!!(
386-
acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood)
387-
)
383+
return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood))
388384
end
389385

390386
"""

src/default_accumulators.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,7 @@ function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator)
5757
return basetypeof(acc)(logp(acc) + logp(acc2))
5858
end
5959

60-
function Base.:+(acc1::LogProbAccumulator, acc2::LogProbAccumulator)
61-
if basetypeof(acc1) !== basetypeof(acc2)
62-
msg = "Cannot add accumulators of different types: $(basetypeof(acc1)) and $(basetypeof(acc2))"
63-
throw(ArgumentError(msg))
64-
end
65-
return basetypeof(acc1)(logp(acc1) + logp(acc2))
66-
end
60+
acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val)
6761

6862
Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc)))
6963

@@ -99,7 +93,7 @@ logp(acc::LogPriorAccumulator) = acc.logp
9993
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
10094

10195
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
102-
return acc + LogPriorAccumulator(logpdf(right, val))
96+
return acclogp(acc, logpdf(right, val))
10397
end
10498
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
10599

@@ -143,7 +137,7 @@ logp(acc::LogJacobianAccumulator) = acc.logJ
143137
accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian
144138

145139
function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right)
146-
return acc + LogJacobianAccumulator(logjac)
140+
return acclogp(acc, logjac)
147141
end
148142
accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc
149143

@@ -169,7 +163,7 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
169163
# Note that it's important to use the loglikelihood function here, not logpdf, because
170164
# they handle vectors differently:
171165
# https://github.com/JuliaStats/Distributions.jl/issues/1972
172-
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
166+
return acclogp(acc, Distributions.loglikelihood(right, left))
173167
end
174168

175169
"""
@@ -208,7 +202,7 @@ end
208202

209203
function Base.show(io::IO, acc::VariableOrderAccumulator)
210204
return print(
211-
io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
205+
io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))"
212206
)
213207
end
214208

test/accumulators.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,11 @@ using DynamicPPL:
3939
end
4040

4141
@testset "addition and incrementation" begin
42-
@test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) ==
43-
LogPriorAccumulator(2.0f0)
44-
@test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) ==
45-
LogPriorAccumulator(2.0)
46-
@test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) ==
42+
@test acclogp(LogPriorAccumulator(1.0f0), 1.0f0) == LogPriorAccumulator(2.0f0)
43+
@test acclogp(LogPriorAccumulator(1.0), 1.0f0) == LogPriorAccumulator(2.0)
44+
@test acclogp(LogLikelihoodAccumulator(1.0f0), 1.0f0) ==
4745
LogLikelihoodAccumulator(2.0f0)
48-
@test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) ==
46+
@test acclogp(LogLikelihoodAccumulator(1.0), 1.0f0) ==
4947
LogLikelihoodAccumulator(2.0)
5048
@test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1)
5149
@test increment(VariableOrderAccumulator{UInt8}()) ==

0 commit comments

Comments
 (0)