Skip to content

Commit 6866091

Browse files
committed
Add a test, plus a bunch of == methods
1 parent 92f72d0 commit 6866091

File tree

8 files changed

+95
-9
lines changed

8 files changed

+95
-9
lines changed

src/debug_utils.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,22 @@ end
156156
const _DEBUG_ACC_NAME = :Debug
157157
DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME
158158

159+
function Base.:(==)(acc1::DebugAccumulator, acc2::DebugAccumulator)
160+
return (
161+
acc1.varnames_seen == acc2.varnames_seen &&
162+
acc1.statements == acc2.statements &&
163+
acc1.error_on_failure == acc2.error_on_failure
164+
)
165+
end
166+
159167
function _zero(acc::DebugAccumulator)
160168
return DebugAccumulator(
161169
OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure
162170
)
163171
end
164-
reset(acc::DebugAccumulator) = _zero(acc)
165-
split(acc::DebugAccumulator) = _zero(acc)
166-
function combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
172+
DynamicPPL.reset(acc::DebugAccumulator) = _zero(acc)
173+
DynamicPPL.split(acc::DebugAccumulator) = _zero(acc)
174+
function DynamicPPL.combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
167175
return DebugAccumulator(
168176
merge(acc1.varnames_seen, acc2.varnames_seen),
169177
vcat(acc1.statements, acc2.statements),

src/extract_priors.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ end
1010

1111
accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator
1212

13+
function Base.:(==)(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator)
14+
return acc1.priors == acc2.priors
15+
end
16+
1317
_zero(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
1418
reset(acc::PriorDistributionAccumulator) = _zero(acc)
1519
split(acc::PriorDistributionAccumulator) = _zero(acc)

src/pointwise_logdensities.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
3232
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
3333
end
3434

35+
function Base.:(==)(
36+
acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2}
37+
) where {wlp1,wlp2}
38+
return (wlp1 == wlp2 && acc1.logps == acc2.logps)
39+
end
40+
3541
function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
3642
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
3743
end

src/values_as_in_model.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ function ValuesAsInModelAccumulator(include_colon_eq)
2020
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
2121
end
2222

23+
function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
24+
return (acc1.include_colon_eq == acc2.include_colon_eq && acc1.values == acc2.values)
25+
end
26+
2327
function Base.copy(acc::ValuesAsInModelAccumulator)
2428
return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq)
2529
end

test/accumulators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using DynamicPPL:
1313
convert_eltype,
1414
getacc,
1515
map_accumulator,
16+
reset,
1617
setacc!!,
1718
split
1819

test/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@
204204
end
205205

206206
# Reset the logp accumulators.
207-
svi_eval = DynamicPPL.resetlogp!!(svi_eval)
207+
svi_eval = DynamicPPL.resetaccs!!(svi_eval)
208208

209209
# Compute `logjoint` using the varinfo.
210210
logπ = logjoint(model, svi_eval)

test/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
@test getlogjoint(vi) == lp
2626
@test getlogjoint(threadsafe_vi) == lp + 42
2727

28-
threadsafe_vi = resetlogp!!(threadsafe_vi)
28+
threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi)
2929
@test iszero(getlogjoint(threadsafe_vi))
3030
expected_accs = DynamicPPL.AccumulatorTuple(
3131
(DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...

test/varinfo.jl

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ end
112112
test_base(SimpleVarInfo(DynamicPPL.VarNamedVector()))
113113
end
114114

115-
@testset "get/set/acc/resetlogp" begin
115+
@testset "get/set/acclogp" begin
116116
function test_varinfo_logp!(vi)
117117
@test DynamicPPL.getlogjoint(vi) === 0.0
118118
vi = DynamicPPL.setlogprior!!(vi, 1.0)
@@ -131,8 +131,6 @@ end
131131
@test DynamicPPL.getlogprior(vi) === 2.0
132132
@test DynamicPPL.getloglikelihood(vi) === 2.0
133133
@test DynamicPPL.getlogjoint(vi) === 4.0
134-
vi = DynamicPPL.resetlogp!!(vi)
135-
@test DynamicPPL.getlogjoint(vi) === 0.0
136134
end
137135

138136
vi = VarInfo()
@@ -143,7 +141,7 @@ end
143141
test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
144142
end
145143

146-
@testset "accumulators" begin
144+
@testset "logp accumulators" begin
147145
@model function demo()
148146
a ~ Normal()
149147
b ~ Normal()
@@ -227,6 +225,71 @@ end
227225
@test_throws r"has no field `?LogPrior" getlogjoint(vi)
228226
end
229227

228+
@testset "resetaccs" begin
229+
# Put in a bunch of accumulators, check that they're all reset either
230+
# when we call resetaccs!!, empty!!, or evaluate!!.
231+
@model function demo()
232+
a ~ Normal()
233+
return x ~ Normal(a)
234+
end
235+
model = demo()
236+
vi_orig = VarInfo(model)
237+
# It already has the logp accumulators, so let's add in some more.
238+
vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.DebugUtils.DebugAccumulator(true))
239+
vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.ValuesAsInModelAccumulator(true))
240+
vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.PriorDistributionAccumulator())
241+
vi_orig = DynamicPPL.setacc!!(
242+
vi_orig, DynamicPPL.PointwiseLogProbAccumulator{:both}()
243+
)
244+
# And evaluate the model once so that they are populated.
245+
_, vi_orig = DynamicPPL.evaluate!!(model, vi_orig)
246+
247+
function all_accs_empty(vi::AbstractVarInfo)
248+
for acc_key in keys(DynamicPPL.getaccs(vi))
249+
acc = DynamicPPL.getacc(vi, Val(acc_key))
250+
acc == DynamicPPL.reset(acc) || return false
251+
end
252+
return true
253+
end
254+
255+
@test !all_accs_empty(vi_orig)
256+
257+
vi = DynamicPPL.resetaccs!!(deepcopy(vi_orig))
258+
@test all_accs_empty(vi)
259+
@test getlogjoint(vi) == 0.0 # for good measure
260+
@test getlogprior(vi) == 0.0
261+
@test getloglikelihood(vi) == 0.0
262+
263+
vi = DynamicPPL.empty!!(deepcopy(vi_orig))
264+
@test all_accs_empty(vi)
265+
@test getlogjoint(vi) == 0.0
266+
@test getlogprior(vi) == 0.0
267+
@test getloglikelihood(vi) == 0.0
268+
269+
function all_accs_same(vi1::AbstractVarInfo, vi2::AbstractVarInfo)
270+
# Check that they have the same accs
271+
keys1 = Set(keys(DynamicPPL.getaccs(vi1)))
272+
keys2 = Set(keys(DynamicPPL.getaccs(vi2)))
273+
keys1 == keys2 || return false
274+
# Check that they have the same values
275+
for acc_key in keys1
276+
acc1 = DynamicPPL.getacc(vi1, Val(acc_key))
277+
acc2 = DynamicPPL.getacc(vi2, Val(acc_key))
278+
if acc1 != acc2
279+
@show acc1, acc2
280+
end
281+
acc1 == acc2 || return false
282+
end
283+
return true
284+
end
285+
# Hopefully this doesn't matter
286+
@test all_accs_same(vi_orig, deepcopy(vi_orig))
287+
# If we re-evaluate, then we expect the accs to be reset prior to evaluation.
288+
# Thus after re-evaluation, the accs should be exactly the same as before.
289+
_, vi = DynamicPPL.evaluate!!(model, deepcopy(vi_orig))
290+
@test all_accs_same(vi, vi_orig)
291+
end
292+
230293
@testset "flags" begin
231294
# Test flag setting:
232295
# is_flagged, set_flag!, unset_flag!

0 commit comments

Comments
 (0)