Skip to content

Commit ea60973

Browse files
committed
Add unit tests for new functions
1 parent 523f411 commit ea60973

File tree

1 file changed

+102
-1
lines changed

1 file changed

+102
-1
lines changed

test/contexts.jl

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@ using DynamicPPL:
1111
IsParent,
1212
PointwiseLogdensityContext,
1313
contextual_isassumption,
14+
FixedContext,
1415
ConditionContext,
1516
decondition_context,
1617
hasconditioned,
1718
getconditioned,
19+
conditioned,
20+
fixed,
1821
hasconditioned_nested,
19-
getconditioned_nested
22+
getconditioned_nested,
23+
collapse_prefix_stack,
24+
prefix_cond_and_fixed_variables,
25+
getvalue
2026

2127
using EnzymeCore
2228

@@ -306,4 +312,99 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
306312
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
307313
end
308314
end
315+
316+
@testset "PrefixContext + Condition/FixedContext interactions" begin
317+
@testset "prefix_cond_and_fixed_variables" begin
318+
c1 = ConditionContext((c=1, d=2))
319+
c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a))
320+
@test c1_prefixed isa ConditionContext
321+
@test childcontext(c1_prefixed) isa DefaultContext
322+
@test c1_prefixed.values[@varname(a.c)] == 1
323+
@test c1_prefixed.values[@varname(a.d)] == 2
324+
325+
c2 = FixedContext((f=1, g=2))
326+
c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a))
327+
@test c2_prefixed isa FixedContext
328+
@test childcontext(c2_prefixed) isa DefaultContext
329+
@test c2_prefixed.values[@varname(a.f)] == 1
330+
@test c2_prefixed.values[@varname(a.g)] == 2
331+
332+
c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2)))
333+
c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a))
334+
c3_prefixed_child = childcontext(c3_prefixed)
335+
@test c3_prefixed isa ConditionContext
336+
@test c3_prefixed.values[@varname(a.c)] == 1
337+
@test c3_prefixed.values[@varname(a.d)] == 2
338+
@test c3_prefixed_child isa FixedContext
339+
@test c3_prefixed_child.values[@varname(a.f)] == 1
340+
@test c3_prefixed_child.values[@varname(a.g)] == 2
341+
@test childcontext(c3_prefixed_child) isa DefaultContext
342+
end
343+
344+
@testset "collapse_prefix_stack" begin
345+
# Utility function to make sure that there are no PrefixContexts in
346+
# the context stack.
347+
function has_no_prefixcontexts(ctx::AbstractContext)
348+
return !(ctx isa PrefixContext) && (
349+
NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx))
350+
)
351+
end
352+
353+
# Prefix -> Condition
354+
c1 = PrefixContext{:a}(ConditionContext((c=1, d=2)))
355+
c1 = collapse_prefix_stack(c1)
356+
@test has_no_prefixcontexts(c1)
357+
c1_vals = conditioned(c1)
358+
@test length(c1_vals) == 2
359+
@test getvalue(c1_vals, @varname(a.c)) == 1
360+
@test getvalue(c1_vals, @varname(a.d)) == 2
361+
362+
# Condition -> Prefix
363+
c2 = (ConditionContext((c=1, d=2), PrefixContext{:a}(DefaultContext())))
364+
c2 = collapse_prefix_stack(c2)
365+
@test has_no_prefixcontexts(c2)
366+
c2_vals = conditioned(c2)
367+
@test length(c2_vals) == 2
368+
@test getvalue(c2_vals, @varname(c)) == 1
369+
@test getvalue(c2_vals, @varname(d)) == 2
370+
371+
# Prefix -> Fixed
372+
c3 = PrefixContext{:a}(FixedContext((f=1, g=2)))
373+
c3 = collapse_prefix_stack(c3)
374+
c3_vals = fixed(c3)
375+
@test length(c3_vals) == 2
376+
@test length(c3_vals) == 2
377+
@test getvalue(c3_vals, @varname(a.f)) == 1
378+
@test getvalue(c3_vals, @varname(a.g)) == 2
379+
380+
# Fixed -> Prefix
381+
c4 = (FixedContext((f=1, g=2), PrefixContext{:a}(DefaultContext())))
382+
c4 = collapse_prefix_stack(c4)
383+
@test has_no_prefixcontexts(c4)
384+
c4_vals = fixed(c4)
385+
@test length(c4_vals) == 2
386+
@test getvalue(c4_vals, @varname(f)) == 1
387+
@test getvalue(c4_vals, @varname(g)) == 2
388+
389+
# Prefix -> Condition -> Prefix -> Condition
390+
c5 = PrefixContext{:a}(
391+
ConditionContext((c=1,), PrefixContext{:b}(ConditionContext((d=2,))))
392+
)
393+
c5 = collapse_prefix_stack(c5)
394+
@test has_no_prefixcontexts(c5)
395+
c5_vals = conditioned(c5)
396+
@test length(c5_vals) == 2
397+
@test getvalue(c5_vals, @varname(a.c)) == 1
398+
@test getvalue(c5_vals, @varname(a.b.d)) == 2
399+
400+
# Prefix -> Condition -> Prefix -> Fixed
401+
c6 = PrefixContext{:a}(
402+
ConditionContext((c=1,), PrefixContext{:b}(FixedContext((d=2,))))
403+
)
404+
c6 = collapse_prefix_stack(c6)
405+
@test has_no_prefixcontexts(c6)
406+
@test conditioned(c6) == Dict(@varname(a.c) => 1)
407+
@test fixed(c6) == Dict(@varname(a.b.d) => 2)
408+
end
409+
end
309410
end

0 commit comments

Comments
 (0)