@@ -11,12 +11,18 @@ using DynamicPPL:
11
11
IsParent,
12
12
PointwiseLogdensityContext,
13
13
contextual_isassumption,
14
+ FixedContext,
14
15
ConditionContext,
15
16
decondition_context,
16
17
hasconditioned,
17
18
getconditioned,
19
+ conditioned,
20
+ fixed,
18
21
hasconditioned_nested,
19
- getconditioned_nested
22
+ getconditioned_nested,
23
+ collapse_prefix_stack,
24
+ prefix_cond_and_fixed_variables,
25
+ getvalue
20
26
21
27
using EnzymeCore
22
28
@@ -306,4 +312,99 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
306
312
@test logprior (model_fixed, (; m)) == logprior (condition (model; s= s), (; m))
307
313
end
308
314
end
315
+
316
+ @testset " PrefixContext + Condition/FixedContext interactions" begin
317
+ @testset " prefix_cond_and_fixed_variables" begin
318
+ c1 = ConditionContext ((c= 1 , d= 2 ))
319
+ c1_prefixed = prefix_cond_and_fixed_variables (c1, @varname (a))
320
+ @test c1_prefixed isa ConditionContext
321
+ @test childcontext (c1_prefixed) isa DefaultContext
322
+ @test c1_prefixed. values[@varname (a. c)] == 1
323
+ @test c1_prefixed. values[@varname (a. d)] == 2
324
+
325
+ c2 = FixedContext ((f= 1 , g= 2 ))
326
+ c2_prefixed = prefix_cond_and_fixed_variables (c2, @varname (a))
327
+ @test c2_prefixed isa FixedContext
328
+ @test childcontext (c2_prefixed) isa DefaultContext
329
+ @test c2_prefixed. values[@varname (a. f)] == 1
330
+ @test c2_prefixed. values[@varname (a. g)] == 2
331
+
332
+ c3 = ConditionContext ((c= 1 , d= 2 ), FixedContext ((f= 1 , g= 2 )))
333
+ c3_prefixed = prefix_cond_and_fixed_variables (c3, @varname (a))
334
+ c3_prefixed_child = childcontext (c3_prefixed)
335
+ @test c3_prefixed isa ConditionContext
336
+ @test c3_prefixed. values[@varname (a. c)] == 1
337
+ @test c3_prefixed. values[@varname (a. d)] == 2
338
+ @test c3_prefixed_child isa FixedContext
339
+ @test c3_prefixed_child. values[@varname (a. f)] == 1
340
+ @test c3_prefixed_child. values[@varname (a. g)] == 2
341
+ @test childcontext (c3_prefixed_child) isa DefaultContext
342
+ end
343
+
344
+ @testset " collapse_prefix_stack" begin
345
+ # Utility function to make sure that there are no PrefixContexts in
346
+ # the context stack.
347
+ function has_no_prefixcontexts (ctx:: AbstractContext )
348
+ return ! (ctx isa PrefixContext) && (
349
+ NodeTrait (ctx) isa IsLeaf || has_no_prefixcontexts (childcontext (ctx))
350
+ )
351
+ end
352
+
353
+ # Prefix -> Condition
354
+ c1 = PrefixContext {:a} (ConditionContext ((c= 1 , d= 2 )))
355
+ c1 = collapse_prefix_stack (c1)
356
+ @test has_no_prefixcontexts (c1)
357
+ c1_vals = conditioned (c1)
358
+ @test length (c1_vals) == 2
359
+ @test getvalue (c1_vals, @varname (a. c)) == 1
360
+ @test getvalue (c1_vals, @varname (a. d)) == 2
361
+
362
+ # Condition -> Prefix
363
+ c2 = (ConditionContext ((c= 1 , d= 2 ), PrefixContext {:a} (DefaultContext ())))
364
+ c2 = collapse_prefix_stack (c2)
365
+ @test has_no_prefixcontexts (c2)
366
+ c2_vals = conditioned (c2)
367
+ @test length (c2_vals) == 2
368
+ @test getvalue (c2_vals, @varname (c)) == 1
369
+ @test getvalue (c2_vals, @varname (d)) == 2
370
+
371
+ # Prefix -> Fixed
372
+ c3 = PrefixContext {:a} (FixedContext ((f= 1 , g= 2 )))
373
+ c3 = collapse_prefix_stack (c3)
374
+ c3_vals = fixed (c3)
375
+ @test length (c3_vals) == 2
376
+ @test length (c3_vals) == 2
377
+ @test getvalue (c3_vals, @varname (a. f)) == 1
378
+ @test getvalue (c3_vals, @varname (a. g)) == 2
379
+
380
+ # Fixed -> Prefix
381
+ c4 = (FixedContext ((f= 1 , g= 2 ), PrefixContext {:a} (DefaultContext ())))
382
+ c4 = collapse_prefix_stack (c4)
383
+ @test has_no_prefixcontexts (c4)
384
+ c4_vals = fixed (c4)
385
+ @test length (c4_vals) == 2
386
+ @test getvalue (c4_vals, @varname (f)) == 1
387
+ @test getvalue (c4_vals, @varname (g)) == 2
388
+
389
+ # Prefix -> Condition -> Prefix -> Condition
390
+ c5 = PrefixContext {:a} (
391
+ ConditionContext ((c= 1 ,), PrefixContext {:b} (ConditionContext ((d= 2 ,))))
392
+ )
393
+ c5 = collapse_prefix_stack (c5)
394
+ @test has_no_prefixcontexts (c5)
395
+ c5_vals = conditioned (c5)
396
+ @test length (c5_vals) == 2
397
+ @test getvalue (c5_vals, @varname (a. c)) == 1
398
+ @test getvalue (c5_vals, @varname (a. b. d)) == 2
399
+
400
+ # Prefix -> Condition -> Prefix -> Fixed
401
+ c6 = PrefixContext {:a} (
402
+ ConditionContext ((c= 1 ,), PrefixContext {:b} (FixedContext ((d= 2 ,))))
403
+ )
404
+ c6 = collapse_prefix_stack (c6)
405
+ @test has_no_prefixcontexts (c6)
406
+ @test conditioned (c6) == Dict (@varname (a. c) => 1 )
407
+ @test fixed (c6) == Dict (@varname (a. b. d) => 2 )
408
+ end
409
+ end
309
410
end
0 commit comments