Skip to content

Commit 2dc61a8

Browse files
committed
Improve acc tests
1 parent bc51d62 commit 2dc61a8

File tree

3 files changed

+106
-18
lines changed

3 files changed

+106
-18
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export AbstractVarInfo,
6666
setlogprior!!,
6767
setlogjac!!,
6868
setloglikelihood!!,
69+
acclogp,
6970
acclogp!!,
7071
acclogjac!!,
7172
acclogprior!!,

src/accumulators.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,9 @@ Implemented as a generated function to enabled constant propagation of the resul
203203
@generated function _joint_keys(
204204
nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
205205
) where {names1,names2}
206-
set_names1 = Set(names1)
207-
set_names2 = Set(names2)
208-
only_in_nt1 = tuple(setdiff(set_names1, set_names2)...)
209-
only_in_nt2 = tuple(setdiff(set_names2, set_names1)...)
210-
in_both = tuple(intersect(set_names1, set_names2)...)
206+
only_in_nt1 = tuple(setdiff(names1, names2)...)
207+
only_in_nt2 = tuple(setdiff(names2, names1)...)
208+
in_both = tuple(intersect(names1, names2)...)
211209
return :($only_in_nt1, $only_in_nt2, $in_both)
212210
end
213211

@@ -226,7 +224,7 @@ function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
226224
accs_in_both = (
227225
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
228226
)
229-
return AccumulatorTuple(accs_in_at1..., accs_in_at2..., accs_in_both...)
227+
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
230228
end
231229

232230
"""

test/accumulators.jl

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,13 @@ using DynamicPPL:
109109
VariableOrderAccumulator(2)
110110
end
111111

112-
@testset "merge and subset" begin
112+
@testset "merge" begin
113113
@test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) ==
114-
LogPriorAccumulator(3.0)
114+
LogPriorAccumulator(2.0)
115115
@test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) ==
116-
LogJacobianAccumulator(3.0)
116+
LogJacobianAccumulator(2.0)
117117
@test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) ==
118-
LogLikelihoodAccumulator(3.0)
118+
LogLikelihoodAccumulator(2.0)
119119

120120
@test merge(
121121
VariableOrderAccumulator(1, Dict{VarName,Int}()),
@@ -132,6 +132,49 @@ using DynamicPPL:
132132
1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3))
133133
)
134134
end
135+
136+
@testset "subset" begin
137+
@test subset(LogPriorAccumulator(1.0), VarName[]) == LogPriorAccumulator(1.0)
138+
@test subset(LogJacobianAccumulator(1.0), VarName[]) ==
139+
LogJacobianAccumulator(1.0)
140+
@test subset(LogLikelihoodAccumulator(1.0), VarName[]) ==
141+
LogLikelihoodAccumulator(1.0)
142+
143+
@test subset(
144+
VariableOrderAccumulator(1, Dict{VarName,Int}()),
145+
VarName[@varname(a), @varname(b)],
146+
) == VariableOrderAccumulator(1, Dict{VarName,Int}())
147+
@test subset(
148+
VariableOrderAccumulator(
149+
2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2))
150+
),
151+
VarName[@varname(a)],
152+
) == VariableOrderAccumulator(2, Dict{VarName,Int}((@varname(a) => 1)))
153+
@test subset(
154+
VariableOrderAccumulator(
155+
2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2))
156+
),
157+
VarName[],
158+
) == VariableOrderAccumulator(2, Dict{VarName,Int}())
159+
@test subset(
160+
VariableOrderAccumulator(
161+
2,
162+
Dict{VarName,Int}((
163+
@varname(a) => 1,
164+
@varname(a.b.c) => 2,
165+
@varname(a.b.c.d[1]) => 2,
166+
@varname(b) => 3,
167+
@varname(c[1]) => 4,
168+
)),
169+
),
170+
VarName[@varname(a.b), @varname(b)],
171+
) == VariableOrderAccumulator(
172+
2,
173+
Dict{VarName,Int}((
174+
@varname(a.b.c) => 2, @varname(a.b.c.d[1]) => 2, @varname(b) => 3
175+
)),
176+
)
177+
end
135178
end
136179

137180
@testset "accumulator tuples" begin
@@ -140,7 +183,7 @@ using DynamicPPL:
140183
lp_f32 = LogPriorAccumulator(1.0f0)
141184
ll_f64 = LogLikelihoodAccumulator(1.0)
142185
ll_f32 = LogLikelihoodAccumulator(1.0f0)
143-
np_i64 = VariableOrderAccumulator(1)
186+
vo_i64 = VariableOrderAccumulator(1)
144187

145188
@testset "constructors" begin
146189
@test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64))
@@ -154,22 +197,22 @@ using DynamicPPL:
154197
end
155198

156199
@testset "basic operations" begin
157-
at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64)
200+
at_all64 = AccumulatorTuple(lp_f64, ll_f64, vo_i64)
158201

159202
@test at_all64[:LogPrior] == lp_f64
160203
@test at_all64[:LogLikelihood] == ll_f64
161-
@test at_all64[:VariableOrder] == np_i64
204+
@test at_all64[:VariableOrder] == vo_i64
162205

163-
@test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder))
164-
@test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior))
165-
@test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3
206+
@test haskey(AccumulatorTuple(vo_i64), Val(:VariableOrder))
207+
@test ~haskey(AccumulatorTuple(vo_i64), Val(:LogPrior))
208+
@test length(AccumulatorTuple(lp_f64, ll_f64, vo_i64)) == 3
166209
@test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder)
167-
@test collect(at_all64) == [lp_f64, ll_f64, np_i64]
210+
@test collect(at_all64) == [lp_f64, ll_f64, vo_i64]
168211

169212
# Replace the existing LogPriorAccumulator
170213
@test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32
171214
# Check that setacc!! didn't modify the original
172-
@test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64)
215+
@test at_all64 == AccumulatorTuple(lp_f64, ll_f64, vo_i64)
173216
# Add a new accumulator type.
174217
@test setacc!!(AccumulatorTuple(lp_f64), ll_f64) ==
175218
AccumulatorTuple(lp_f64, ll_f64)
@@ -197,6 +240,52 @@ using DynamicPPL:
197240
acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood)
198241
) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0))
199242
end
243+
244+
@testset "merge" begin
245+
vo1 = VariableOrderAccumulator(
246+
1, Dict{VarName,Int}(@varname(a) => 1, @varname(b) => 1)
247+
)
248+
vo2 = VariableOrderAccumulator(
249+
2, Dict{VarName,Int}(@varname(a) => 2, @varname(c) => 2)
250+
)
251+
accs1 = AccumulatorTuple(lp_f64, ll_f64, vo1)
252+
accs2 = AccumulatorTuple(lp_f32, vo2)
253+
@test merge(accs1, accs2) == AccumulatorTuple(
254+
ll_f64,
255+
lp_f32,
256+
VariableOrderAccumulator(
257+
2,
258+
Dict{VarName,Int}(@varname(a) => 2, @varname(b) => 1, @varname(c) => 2),
259+
),
260+
)
261+
@test merge(AccumulatorTuple(), accs1) == accs1
262+
@test merge(accs1, AccumulatorTuple()) == accs1
263+
@test merge(accs1, accs1) == accs1
264+
end
265+
266+
@testset "subset" begin
267+
accs = AccumulatorTuple(
268+
lp_f64,
269+
ll_f64,
270+
VariableOrderAccumulator(
271+
1,
272+
Dict{VarName,Int}(
273+
@varname(a.b) => 1, @varname(a.b[1]) => 2, @varname(b) => 1
274+
),
275+
),
276+
)
277+
278+
@test subset(accs, VarName[]) == AccumulatorTuple(
279+
lp_f64, ll_f64, VariableOrderAccumulator(1, Dict{VarName,Int}())
280+
)
281+
@test subset(accs, VarName[@varname(a)]) == AccumulatorTuple(
282+
lp_f64,
283+
ll_f64,
284+
VariableOrderAccumulator(
285+
1, Dict{VarName,Int}(@varname(a.b) => 1, @varname(a.b[1]) => 2)
286+
),
287+
)
288+
end
200289
end
201290
end
202291

0 commit comments

Comments
 (0)