@@ -11,12 +11,18 @@ using DynamicPPL:
11
11
IsParent,
12
12
PointwiseLogdensityContext,
13
13
contextual_isassumption,
14
+ FixedContext,
14
15
ConditionContext,
15
16
decondition_context,
16
17
hasconditioned,
17
18
getconditioned,
19
+ conditioned,
20
+ fixed,
18
21
hasconditioned_nested,
19
- getconditioned_nested
22
+ getconditioned_nested,
23
+ collapse_prefix_stack,
24
+ prefix_cond_and_fixed_variables,
25
+ getvalue
20
26
21
27
using EnzymeCore
22
28
@@ -156,6 +162,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
156
162
@test DynamicPPL. prefix (ctx4, vn) == @varname (b. a. x[1 ])
157
163
end
158
164
165
+ @testset " prefix_and_strip_contexts" begin
166
+ vn = @varname (x[1 ])
167
+ ctx1 = PrefixContext {:a} (DefaultContext ())
168
+ new_vn, new_ctx = DynamicPPL. prefix_and_strip_contexts (ctx1, vn)
169
+ @test new_vn == @varname (a. x[1 ])
170
+ @test new_ctx == DefaultContext ()
171
+
172
+ ctx2 = SamplingContext (PrefixContext {:a} (DefaultContext ()))
173
+ new_vn, new_ctx = DynamicPPL. prefix_and_strip_contexts (ctx2, vn)
174
+ @test new_vn == @varname (a. x[1 ])
175
+ @test new_ctx == SamplingContext ()
176
+
177
+ ctx3 = PrefixContext {:a} (ConditionContext ((a= 1 ,)))
178
+ new_vn, new_ctx = DynamicPPL. prefix_and_strip_contexts (ctx3, vn)
179
+ @test new_vn == @varname (a. x[1 ])
180
+ @test new_ctx == ConditionContext ((a= 1 ,))
181
+
182
+ ctx4 = SamplingContext (PrefixContext {:a} (ConditionContext ((a= 1 ,))))
183
+ new_vn, new_ctx = DynamicPPL. prefix_and_strip_contexts (ctx4, vn)
184
+ @test new_vn == @varname (a. x[1 ])
185
+ @test new_ctx == SamplingContext (ConditionContext ((a= 1 ,)))
186
+ end
187
+
159
188
@testset " evaluation: $(model. f) " for model in DynamicPPL. TestUtils. DEMO_MODELS
160
189
prefix = :my_prefix
161
190
context = DynamicPPL. PrefixContext {prefix} (SamplingContext ())
@@ -306,4 +335,99 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
306
335
@test logprior (model_fixed, (; m)) == logprior (condition (model; s= s), (; m))
307
336
end
308
337
end
338
+
339
+ @testset " PrefixContext + Condition/FixedContext interactions" begin
340
+ @testset " prefix_cond_and_fixed_variables" begin
341
+ c1 = ConditionContext ((c= 1 , d= 2 ))
342
+ c1_prefixed = prefix_cond_and_fixed_variables (c1, @varname (a))
343
+ @test c1_prefixed isa ConditionContext
344
+ @test childcontext (c1_prefixed) isa DefaultContext
345
+ @test c1_prefixed. values[@varname (a. c)] == 1
346
+ @test c1_prefixed. values[@varname (a. d)] == 2
347
+
348
+ c2 = FixedContext ((f= 1 , g= 2 ))
349
+ c2_prefixed = prefix_cond_and_fixed_variables (c2, @varname (a))
350
+ @test c2_prefixed isa FixedContext
351
+ @test childcontext (c2_prefixed) isa DefaultContext
352
+ @test c2_prefixed. values[@varname (a. f)] == 1
353
+ @test c2_prefixed. values[@varname (a. g)] == 2
354
+
355
+ c3 = ConditionContext ((c= 1 , d= 2 ), FixedContext ((f= 1 , g= 2 )))
356
+ c3_prefixed = prefix_cond_and_fixed_variables (c3, @varname (a))
357
+ c3_prefixed_child = childcontext (c3_prefixed)
358
+ @test c3_prefixed isa ConditionContext
359
+ @test c3_prefixed. values[@varname (a. c)] == 1
360
+ @test c3_prefixed. values[@varname (a. d)] == 2
361
+ @test c3_prefixed_child isa FixedContext
362
+ @test c3_prefixed_child. values[@varname (a. f)] == 1
363
+ @test c3_prefixed_child. values[@varname (a. g)] == 2
364
+ @test childcontext (c3_prefixed_child) isa DefaultContext
365
+ end
366
+
367
+ @testset " collapse_prefix_stack" begin
368
+ # Utility function to make sure that there are no PrefixContexts in
369
+ # the context stack.
370
+ function has_no_prefixcontexts (ctx:: AbstractContext )
371
+ return ! (ctx isa PrefixContext) && (
372
+ NodeTrait (ctx) isa IsLeaf || has_no_prefixcontexts (childcontext (ctx))
373
+ )
374
+ end
375
+
376
+ # Prefix -> Condition
377
+ c1 = PrefixContext {:a} (ConditionContext ((c= 1 , d= 2 )))
378
+ c1 = collapse_prefix_stack (c1)
379
+ @test has_no_prefixcontexts (c1)
380
+ c1_vals = conditioned (c1)
381
+ @test length (c1_vals) == 2
382
+ @test getvalue (c1_vals, @varname (a. c)) == 1
383
+ @test getvalue (c1_vals, @varname (a. d)) == 2
384
+
385
+ # Condition -> Prefix
386
+ c2 = (ConditionContext ((c= 1 , d= 2 ), PrefixContext {:a} (DefaultContext ())))
387
+ c2 = collapse_prefix_stack (c2)
388
+ @test has_no_prefixcontexts (c2)
389
+ c2_vals = conditioned (c2)
390
+ @test length (c2_vals) == 2
391
+ @test getvalue (c2_vals, @varname (c)) == 1
392
+ @test getvalue (c2_vals, @varname (d)) == 2
393
+
394
+ # Prefix -> Fixed
395
+ c3 = PrefixContext {:a} (FixedContext ((f= 1 , g= 2 )))
396
+ c3 = collapse_prefix_stack (c3)
397
+ c3_vals = fixed (c3)
398
+ @test length (c3_vals) == 2
399
+ @test length (c3_vals) == 2
400
+ @test getvalue (c3_vals, @varname (a. f)) == 1
401
+ @test getvalue (c3_vals, @varname (a. g)) == 2
402
+
403
+ # Fixed -> Prefix
404
+ c4 = (FixedContext ((f= 1 , g= 2 ), PrefixContext {:a} (DefaultContext ())))
405
+ c4 = collapse_prefix_stack (c4)
406
+ @test has_no_prefixcontexts (c4)
407
+ c4_vals = fixed (c4)
408
+ @test length (c4_vals) == 2
409
+ @test getvalue (c4_vals, @varname (f)) == 1
410
+ @test getvalue (c4_vals, @varname (g)) == 2
411
+
412
+ # Prefix -> Condition -> Prefix -> Condition
413
+ c5 = PrefixContext {:a} (
414
+ ConditionContext ((c= 1 ,), PrefixContext {:b} (ConditionContext ((d= 2 ,))))
415
+ )
416
+ c5 = collapse_prefix_stack (c5)
417
+ @test has_no_prefixcontexts (c5)
418
+ c5_vals = conditioned (c5)
419
+ @test length (c5_vals) == 2
420
+ @test getvalue (c5_vals, @varname (a. c)) == 1
421
+ @test getvalue (c5_vals, @varname (a. b. d)) == 2
422
+
423
+ # Prefix -> Condition -> Prefix -> Fixed
424
+ c6 = PrefixContext {:a} (
425
+ ConditionContext ((c= 1 ,), PrefixContext {:b} (FixedContext ((d= 2 ,))))
426
+ )
427
+ c6 = collapse_prefix_stack (c6)
428
+ @test has_no_prefixcontexts (c6)
429
+ @test conditioned (c6) == Dict (@varname (a. c) => 1 )
430
+ @test fixed (c6) == Dict (@varname (a. b. d) => 2 )
431
+ end
432
+ end
309
433
end
0 commit comments