Skip to content

Commit b5573a3

Browse files
committed
Add tests for ConditionContext/decondition_context
1 parent c728e95 commit b5573a3

File tree

2 files changed

+82
-29
lines changed

2 files changed

+82
-29
lines changed

src/contexts.jl

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,18 @@ const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
331331
const DictConditionContext = ConditionContext{<:AbstractDict}
332332

333333
# Use DefaultContext as the default base context
334-
ConditionContext(values::Union{NamedTuple,AbstractDict}) = ConditionContext(values, DefaultContext())
334+
function ConditionContext(values::Union{NamedTuple,AbstractDict})
335+
return ConditionContext(values, DefaultContext())
336+
end
335337
# Optimisation when there are no values to condition on
336338
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
337-
# Try to avoid nested `ConditionContext`.
339+
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
340+
# values inside the child context, thus giving precedence to the outermost
341+
# `ConditionContext`.
338342
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
339-
# Note that this potentially overrides values from `context`, thus giving
340-
# precedence to the outmost `ConditionContext`.
343+
return ConditionContext(merge(context.values, values), childcontext(context))
344+
end
345+
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
341346
return ConditionContext(merge(context.values, values), childcontext(context))
342347
end
343348

@@ -433,32 +438,17 @@ end
433438
function decondition_context(context::ConditionContext)
434439
return decondition_context(childcontext(context))
435440
end
436-
function decondition_context(context::ConditionContext, sym)
437-
return ConditionContext(
438-
decondition_context(childcontext(context), sym), BangBang.delete!!(context.values, sym)
439-
)
440-
end
441441
function decondition_context(context::ConditionContext, sym, syms...)
442-
return decondition_context(
443-
ConditionContext(
444-
decondition_context(childcontext(context), syms...),
445-
BangBang.delete!!(context.values, sym),
446-
),
447-
syms...,
448-
)
449-
end
450-
451-
function decondition_context(
452-
context::NamedConditionContext, vn::VarName{sym}
453-
) where {sym}
454-
return ConditionContext(
455-
decondition_context(childcontext(context), vn), BangBang.delete!!(context.values, sym)
456-
)
457-
end
458-
function decondition_context(context::ConditionContext, vn::VarName)
459-
return ConditionContext(
460-
decondition_context(childcontext(context), vn), BangBang.delete!!(context.values, vn)
461-
)
442+
new_values = deepcopy(context.values)
443+
for s in (sym, syms...)
444+
new_values = BangBang.delete!!(new_values, s)
445+
end
446+
return if length(new_values) == 0
447+
# No more values left, can unwrap
448+
decondition_context(childcontext(context), syms...)
449+
else
450+
ConditionContext(new_values, decondition_context(childcontext(context), syms...))
451+
end
462452
end
463453

464454
"""

test/contexts.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using DynamicPPL:
1111
PointwiseLogdensityContext,
1212
contextual_isassumption,
1313
ConditionContext,
14+
decondition_context,
1415
hasconditioned,
1516
getconditioned,
1617
hasconditioned_nested,
@@ -196,6 +197,68 @@ end
196197
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
197198
end
198199

200+
@testset "ConditionContext" begin
201+
@testset "Nesting" begin
202+
@testset "NamedTuple" begin
203+
n1 = (x=1, y=2)
204+
n2 = (x=3,)
205+
# Values from outer context should override inner one
206+
ctx1 = ConditionContext(n1, ConditionContext(n2))
207+
@test ctx1.values == (x=1, y=2)
208+
# Check that the two ConditionContexts are collapsed
209+
@test childcontext(ctx1) isa DefaultContext
210+
# Then test the nesting the other way round
211+
ctx2 = ConditionContext(n2, ConditionContext(n1))
212+
@test ctx2.values == (x=3, y=2)
213+
@test childcontext(ctx2) isa DefaultContext
214+
end
215+
216+
@testset "Dict" begin
217+
# Same tests as NamedTuple above
218+
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
219+
d2 = Dict(@varname(x) => 3)
220+
ctx1 = ConditionContext(d1, ConditionContext(d2))
221+
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
222+
@test childcontext(ctx1) isa DefaultContext
223+
ctx2 = ConditionContext(d2, ConditionContext(d1))
224+
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
225+
@test childcontext(ctx2) isa DefaultContext
226+
end
227+
end
228+
229+
@testset "decondition_context" begin
230+
@testset "NamedTuple" begin
231+
ctx = ConditionContext((x=1, y=2, z=3))
232+
# Decondition all variables
233+
@test decondition_context(ctx) isa DefaultContext
234+
# Decondition only some variables
235+
dctx = decondition_context(ctx, :x)
236+
@test dctx isa ConditionContext
237+
@test dctx.values == (y=2, z=3)
238+
dctx = decondition_context(ctx, :y, :z)
239+
@test dctx isa ConditionContext
240+
@test dctx.values == (x=1,)
241+
# Decondition all variables manually
242+
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
243+
end
244+
245+
@testset "Dict" begin
246+
ctx = ConditionContext(Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3))
247+
# Decondition all variables
248+
@test decondition_context(ctx) isa DefaultContext
249+
# Decondition only some variables
250+
dctx = decondition_context(ctx, @varname(x))
251+
@test dctx isa ConditionContext
252+
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
253+
dctx = decondition_context(ctx, @varname(y), @varname(z))
254+
@test dctx isa ConditionContext
255+
@test dctx.values == Dict(@varname(x) => 1)
256+
# Decondition all variables manually
257+
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa DefaultContext
258+
end
259+
end
260+
end
261+
199262
@testset "FixedContext" begin
200263
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
201264
retval = model()

0 commit comments

Comments
 (0)