Skip to content

Commit 31a65a2

Browse files
committed
Add unit tests for new functions
1 parent 523f411 commit 31a65a2

File tree

1 file changed

+125
-1
lines changed

1 file changed

+125
-1
lines changed

test/contexts.jl

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@ using DynamicPPL:
1111
IsParent,
1212
PointwiseLogdensityContext,
1313
contextual_isassumption,
14+
FixedContext,
1415
ConditionContext,
1516
decondition_context,
1617
hasconditioned,
1718
getconditioned,
19+
conditioned,
20+
fixed,
1821
hasconditioned_nested,
19-
getconditioned_nested
22+
getconditioned_nested,
23+
collapse_prefix_stack,
24+
prefix_cond_and_fixed_variables,
25+
getvalue
2026

2127
using EnzymeCore
2228

@@ -156,6 +162,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
156162
@test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1])
157163
end
158164

165+
@testset "prefix_and_strip_contexts" begin
166+
vn = @varname(x[1])
167+
ctx1 = PrefixContext{:a}(DefaultContext())
168+
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn)
169+
@test new_vn == @varname(a.x[1])
170+
@test new_ctx == DefaultContext()
171+
172+
ctx2 = SamplingContext(PrefixContext{:a}(DefaultContext()))
173+
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn)
174+
@test new_vn == @varname(a.x[1])
175+
@test new_ctx == SamplingContext()
176+
177+
ctx3 = PrefixContext{:a}(ConditionContext((a=1,)))
178+
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn)
179+
@test new_vn == @varname(a.x[1])
180+
@test new_ctx == ConditionContext((a=1,))
181+
182+
ctx4 = SamplingContext(PrefixContext{:a}(ConditionContext((a=1,))))
183+
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn)
184+
@test new_vn == @varname(a.x[1])
185+
@test new_ctx == SamplingContext(ConditionContext((a=1,)))
186+
end
187+
159188
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
160189
prefix = :my_prefix
161190
context = DynamicPPL.PrefixContext{prefix}(SamplingContext())
@@ -306,4 +335,99 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
306335
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
307336
end
308337
end
338+
339+
@testset "PrefixContext + Condition/FixedContext interactions" begin
340+
@testset "prefix_cond_and_fixed_variables" begin
341+
c1 = ConditionContext((c=1, d=2))
342+
c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a))
343+
@test c1_prefixed isa ConditionContext
344+
@test childcontext(c1_prefixed) isa DefaultContext
345+
@test c1_prefixed.values[@varname(a.c)] == 1
346+
@test c1_prefixed.values[@varname(a.d)] == 2
347+
348+
c2 = FixedContext((f=1, g=2))
349+
c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a))
350+
@test c2_prefixed isa FixedContext
351+
@test childcontext(c2_prefixed) isa DefaultContext
352+
@test c2_prefixed.values[@varname(a.f)] == 1
353+
@test c2_prefixed.values[@varname(a.g)] == 2
354+
355+
c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2)))
356+
c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a))
357+
c3_prefixed_child = childcontext(c3_prefixed)
358+
@test c3_prefixed isa ConditionContext
359+
@test c3_prefixed.values[@varname(a.c)] == 1
360+
@test c3_prefixed.values[@varname(a.d)] == 2
361+
@test c3_prefixed_child isa FixedContext
362+
@test c3_prefixed_child.values[@varname(a.f)] == 1
363+
@test c3_prefixed_child.values[@varname(a.g)] == 2
364+
@test childcontext(c3_prefixed_child) isa DefaultContext
365+
end
366+
367+
@testset "collapse_prefix_stack" begin
368+
# Utility function to make sure that there are no PrefixContexts in
369+
# the context stack.
370+
function has_no_prefixcontexts(ctx::AbstractContext)
371+
return !(ctx isa PrefixContext) && (
372+
NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx))
373+
)
374+
end
375+
376+
# Prefix -> Condition
377+
c1 = PrefixContext{:a}(ConditionContext((c=1, d=2)))
378+
c1 = collapse_prefix_stack(c1)
379+
@test has_no_prefixcontexts(c1)
380+
c1_vals = conditioned(c1)
381+
@test length(c1_vals) == 2
382+
@test getvalue(c1_vals, @varname(a.c)) == 1
383+
@test getvalue(c1_vals, @varname(a.d)) == 2
384+
385+
# Condition -> Prefix
386+
c2 = (ConditionContext((c=1, d=2), PrefixContext{:a}(DefaultContext())))
387+
c2 = collapse_prefix_stack(c2)
388+
@test has_no_prefixcontexts(c2)
389+
c2_vals = conditioned(c2)
390+
@test length(c2_vals) == 2
391+
@test getvalue(c2_vals, @varname(c)) == 1
392+
@test getvalue(c2_vals, @varname(d)) == 2
393+
394+
# Prefix -> Fixed
395+
c3 = PrefixContext{:a}(FixedContext((f=1, g=2)))
396+
c3 = collapse_prefix_stack(c3)
397+
c3_vals = fixed(c3)
398+
@test length(c3_vals) == 2
399+
@test length(c3_vals) == 2
400+
@test getvalue(c3_vals, @varname(a.f)) == 1
401+
@test getvalue(c3_vals, @varname(a.g)) == 2
402+
403+
# Fixed -> Prefix
404+
c4 = (FixedContext((f=1, g=2), PrefixContext{:a}(DefaultContext())))
405+
c4 = collapse_prefix_stack(c4)
406+
@test has_no_prefixcontexts(c4)
407+
c4_vals = fixed(c4)
408+
@test length(c4_vals) == 2
409+
@test getvalue(c4_vals, @varname(f)) == 1
410+
@test getvalue(c4_vals, @varname(g)) == 2
411+
412+
# Prefix -> Condition -> Prefix -> Condition
413+
c5 = PrefixContext{:a}(
414+
ConditionContext((c=1,), PrefixContext{:b}(ConditionContext((d=2,))))
415+
)
416+
c5 = collapse_prefix_stack(c5)
417+
@test has_no_prefixcontexts(c5)
418+
c5_vals = conditioned(c5)
419+
@test length(c5_vals) == 2
420+
@test getvalue(c5_vals, @varname(a.c)) == 1
421+
@test getvalue(c5_vals, @varname(a.b.d)) == 2
422+
423+
# Prefix -> Condition -> Prefix -> Fixed
424+
c6 = PrefixContext{:a}(
425+
ConditionContext((c=1,), PrefixContext{:b}(FixedContext((d=2,))))
426+
)
427+
c6 = collapse_prefix_stack(c6)
428+
@test has_no_prefixcontexts(c6)
429+
@test conditioned(c6) == Dict(@varname(a.c) => 1)
430+
@test fixed(c6) == Dict(@varname(a.b.d) => 2)
431+
end
432+
end
309433
end

0 commit comments

Comments
 (0)