Skip to content

Commit 14ea3b7

Browse files
committed
Add tests; fix a bunch of nested submodel issues
1 parent 8ab47f7 commit 14ea3b7

File tree

4 files changed

+183
-8
lines changed

4 files changed

+183
-8
lines changed

src/context_implementations.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ function tilde_assume(context::AbstractContext, args...)
5757
return tilde_assume(NodeTrait(tilde_assume, context), context, args...)
5858
end
5959
function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi)
60+
@show "isleaf", vn
6061
return assume(right, vn, vi)
6162
end
6263
function tilde_assume(::IsParent, context::AbstractContext, args...)
@@ -85,12 +86,25 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig
8586
end
8687

8788
function tilde_assume(context::PrefixContext, right, vn, vi)
88-
return tilde_assume(context.context, right, prefix(context, vn), vi)
89+
# The slightly tricky thing about PrefixContext is that they are applied
90+
# from the outside in, so `PrefixContext{:a}(PrefixContext{:b}(ctx))` means
91+
# that variables get prefixed like `a.b.x`.
92+
# This motivates the implementation shown here, where the function
93+
# `prefix_and_strip_contexts` is responsible for not only adding the
94+
# prefixes, but also removing the `PrefixContext`s from the context stack
95+
# so that they don't get applied twice when recursing.
96+
# TODO(penelopeysm): It would be nice to switch this round, but it's a very
97+
# tricky task. Essentially it forces us to use a foldr inside
98+
# `prefix_and_strip_contexts`, rather than a foldl which is what most of
99+
# DynamicPPL uses.
100+
new_vn, new_context = prefix_and_strip_contexts(context, vn)
101+
return tilde_assume(new_context, right, new_vn, vi)
89102
end
90103
function tilde_assume(
91104
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
92105
)
93-
return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi)
106+
new_vn, new_context = prefix_and_strip_contexts(context, vn)
107+
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
94108
end
95109

96110
"""

src/contexts.jl

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,30 @@ function prefix(::IsParent, ctx::AbstractContext, vn::VarName)
276276
return prefix(childcontext(ctx), vn)
277277
end
278278

279+
"""
280+
prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName)
281+
282+
Same as `prefix`, but additionally returns a new context stack that has all the
283+
PrefixContexts removed.
284+
"""
285+
function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
286+
child_context = childcontext(ctx)
287+
# vn_prefixed contains the prefixes from all lower levels
288+
vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts(
289+
child_context, vn
290+
)
291+
return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()),
292+
child_context_without_prefixes
293+
end
294+
function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName)
295+
return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn)
296+
end
297+
prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx)
298+
function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName)
299+
vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn)
300+
return vn, setchildcontext(ctx, new_ctx)
301+
end
302+
279303
"""
280304
prefix(model::Model, x)
281305
@@ -351,6 +375,29 @@ NodeTrait(::ConditionContext) = IsParent()
351375
childcontext(context::ConditionContext) = context.context
352376
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)
353377

378+
"""
379+
collapse_prefix_and_condition(context::AbstractContext)
380+
381+
Apply `PrefixContext`s to any conditioned values inside them, and remove
382+
the `PrefixContext`s from the context stack.
383+
384+
```jldoctest
385+
julia> using DynamicPPL: collapse_prefix_and_condition
386+
387+
julia> c1 = PrefixContext({:a}(ConditionContext((x=1, )))
388+
```
389+
"""
390+
function collapse_prefix_and_condition(context::PrefixContext{Prefix}) where {Prefix}
391+
# Collapse the child context (thus applying any inner prefixes first)
392+
collapsed = collapse_prefix_and_condition(childcontext(context))
393+
# Prefix any conditioned variables with the current prefix
394+
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
395+
# So is this function. In the worst case scenario, this is O(N^2) in the
396+
# depth of the context stack.
397+
return prefix_conditioned_variables(collapsed, VarName{Prefix}())
398+
end
399+
collapse_prefix_and_condition(context::AbstractContext) = context
400+
354401
"""
355402
prefix_conditioned_variables(context::AbstractContext, prefix::VarName)
356403
@@ -427,9 +474,7 @@ function hasconditioned_nested(::IsParent, context, vn)
427474
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
428475
end
429476
function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
430-
return hasconditioned_nested(
431-
prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn
432-
)
477+
return hasconditioned_nested(collapse_prefix_and_condition(context), vn)
433478
end
434479

435480
"""
@@ -447,9 +492,7 @@ function getconditioned_nested(::IsLeaf, context, vn)
447492
return error("context $(context) does not contain value for $vn")
448493
end
449494
function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
450-
return getconditioned_nested(
451-
prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn
452-
)
495+
return getconditioned_nested(collapse_prefix_and_condition(context), vn)
453496
end
454497
function getconditioned_nested(::IsParent, context, vn)
455498
return if hasconditioned(context, vn)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ include("test_util.jl")
6767
include("threadsafe.jl")
6868
include("debug_utils.jl")
6969
include("deprecated.jl")
70+
include("submodels.jl")
7071
end
7172

7273
if GROUP == "All" || GROUP == "Group2"

test/submodels.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
module DPPLSubmodelTests
2+
3+
using DynamicPPL
4+
using Distributions
5+
using Test
6+
7+
@testset "submodels.jl" begin
8+
@testset "Conditioning variables" begin
9+
@testset "Auto prefix" begin
10+
@model function inner()
11+
x ~ Normal()
12+
return y ~ Normal()
13+
end
14+
@model function outer()
15+
return a ~ to_submodel(inner())
16+
end
17+
inner_cond = inner() | (@varname(x) => 1.0)
18+
with_outer_cond = outer() | (@varname(a.x) => 1.0)
19+
20+
# No conditioning
21+
@test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)])
22+
# Conditioning from the outside
23+
@test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(a.y)])
24+
# Conditioning from the inside
25+
@model function outer2()
26+
return a ~ to_submodel(inner_cond)
27+
end
28+
with_inner_cond = outer2()
29+
@test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(a.y)])
30+
end
31+
32+
@testset "No prefix" begin
33+
@model function inner()
34+
x ~ Normal()
35+
return y ~ Normal()
36+
end
37+
@model function outer()
38+
return a ~ to_submodel(inner(), false)
39+
end
40+
inner_cond = inner() | (@varname(x) => 1.0)
41+
with_outer_cond = outer() | (@varname(x) => 1.0)
42+
43+
# No conditioning
44+
@test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)])
45+
# Conditioning from the outside
46+
@test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(y)])
47+
# Conditioning from the inside
48+
@model function outer2()
49+
return a ~ to_submodel(inner_cond, false)
50+
end
51+
with_inner_cond = outer2()
52+
@test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(y)])
53+
end
54+
55+
@testset "Manual prefix" begin
56+
@model function inner()
57+
x ~ Normal()
58+
return y ~ Normal()
59+
end
60+
@model function outer()
61+
return a ~ to_submodel(prefix(inner(), :b), false)
62+
end
63+
inner_cond = inner() | (@varname(x) => 1.0)
64+
with_outer_cond = outer() | (@varname(b.x) => 1.0)
65+
66+
# No conditioning
67+
@test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)])
68+
# Conditioning from the outside
69+
@test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(b.y)])
70+
# Conditioning from the inside
71+
@model function outer2()
72+
return a ~ to_submodel(prefix(inner_cond, :b), false)
73+
end
74+
with_inner_cond = outer2()
75+
@test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(b.y)])
76+
end
77+
78+
@testset "Nested submodels" begin
79+
@model function f()
80+
x ~ Normal()
81+
return y ~ Normal()
82+
end
83+
@model function g()
84+
return _unused ~ to_submodel(prefix(f(), :b), false)
85+
end
86+
@model function h()
87+
return a ~ to_submodel(g())
88+
end
89+
90+
# No conditioning
91+
@test Set(keys(VarInfo(h()))) == Set([@varname(a.b.x), @varname(a.b.y)])
92+
93+
# Conditioning at the top level
94+
condition_h = h() | (@varname(a.b.x) => 1.0)
95+
@test Set(keys(VarInfo(condition_h))) == Set([@varname(a.b.y)])
96+
97+
# Conditioning at the second level
98+
condition_g = g() | (@varname(b.x) => 1.0)
99+
@model function h2()
100+
return a ~ to_submodel(condition_g)
101+
end
102+
@test Set(keys(VarInfo(h2()))) == Set([@varname(a.b.y)])
103+
104+
# Conditioning at the very bottom
105+
condition_f = f() | (@varname(x) => 1.0)
106+
@model function g2()
107+
return _unused ~ to_submodel(prefix(condition_f, :b), false)
108+
end
109+
@model function h3()
110+
return a ~ to_submodel(g2())
111+
end
112+
@test Set(keys(VarInfo(h3()))) == Set([@varname(a.b.y)])
113+
end
114+
end
115+
end
116+
117+
end

0 commit comments

Comments
 (0)