Skip to content

Commit 2d6de8c

Browse files
committed
Fix some tests
1 parent f84988f commit 2d6de8c

File tree

2 files changed

+43
-75
lines changed

2 files changed

+43
-75
lines changed

src/contexts.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,11 @@ function conditioned(context::ConditionContext)
519519
# precedence over decendants of `context`.
520520
return _merge(context.values, conditioned(childcontext(context)))
521521
end
522+
function conditioned(context::PrefixContext{Prefix}) where {Prefix}
523+
return conditioned(
524+
prefix_conditioned_variables(childcontext(context), VarName{Prefix}())
525+
)
526+
end
522527

523528
struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
524529
values::Values

test/contexts.jl

Lines changed: 38 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test, DynamicPPL, Accessors
2+
using AbstractPPL: getoptic
23
using DynamicPPL:
34
leafcontext,
45
setleafcontext,
@@ -57,7 +58,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
5758
(x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))
5859
),
5960
:condition3 => ConditionContext(
60-
(x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0)))
61+
(x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(y) => 2.0)))
6162
),
6263
:condition4 => ConditionContext((x=[1.0, missing],)),
6364
)
@@ -70,91 +71,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
7071
end
7172
end
7273

73-
@testset "contextual_isassumption" begin
74-
@testset "$(name)" for (name, context) in contexts
75-
# Any `context` should return `true` by default.
76-
@test contextual_isassumption(context, VarName{gensym(:x)}())
77-
78-
if any(Base.Fix2(isa, ConditionContext), context)
79-
# We have a `ConditionContext` among us.
80-
# Let's first extract the conditioned variables.
81-
conditioned_values = DynamicPPL.conditioned(context)
82-
83-
# The conditioned values might be a NamedTuple, or a Dict.
84-
# We convert to a Dict for consistency
85-
if conditioned_values isa NamedTuple
86-
conditioned_values = Dict(
87-
VarName{sym}() => val for (sym, val) in pairs(conditioned_values)
88-
)
89-
end
90-
91-
for (vn, val) in pairs(conditioned_values)
92-
# We need to drop the prefix of `var` since in `contextual_isassumption`
93-
# it will be threaded through the `PrefixContext` before it reaches
94-
# `ConditionContext` with the conditioned variable.
95-
vn_without_prefix = if getoptic(vn) isa PropertyLens
96-
# Hacky: This assumes that there is exactly one level of prefixing
97-
# that we need to undo. This is appropriate for the :condition3
98-
# test case above, but is not generally correct.
99-
AbstractPPL.unprefix(vn, VarName{getsym(vn)}())
100-
else
101-
vn
102-
end
103-
104-
@show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
105-
# Let's check elementwise.
106-
for vn_child in
107-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
108-
if getoptic(vn_child)(val) === missing
109-
@test contextual_isassumption(context, vn_child)
110-
else
111-
@test !contextual_isassumption(context, vn_child)
112-
end
113-
end
114-
end
115-
end
116-
end
117-
end
74+
@testset "extracting conditioned values" begin
75+
# This testset tests `contextual_isassumption`, `getconditioned_nested`, and
76+
# `hasconditioned_nested`.
11877

119-
@testset "getconditioned_nested & hasconditioned_nested" begin
120-
@testset "$name" for (name, context) in contexts
78+
@testset "$(name)" for (name, context) in contexts
79+
# If the varname doesn't exist, it should always be an assumption.
12180
fake_vn = VarName{gensym(:x)}()
81+
@test contextual_isassumption(context, fake_vn)
12282
@test !hasconditioned_nested(context, fake_vn)
12383
@test_throws ErrorException getconditioned_nested(context, fake_vn)
12484

12585
if any(Base.Fix2(isa, ConditionContext), context)
126-
# `ConditionContext` specific.
127-
86+
# We have a `ConditionContext` among us.
12887
# Let's first extract the conditioned variables.
12988
conditioned_values = DynamicPPL.conditioned(context)
89+
13090
# The conditioned values might be a NamedTuple, or a Dict.
13191
# We convert to a Dict for consistency
132-
if conditioned_values isa NamedTuple
133-
conditioned_values = Dict(
134-
VarName{sym}() => val for (sym, val) in pairs(conditioned_values)
135-
)
136-
end
137-
138-
for (vn, val) in pairs(conditioned_values)
139-
# We need to drop the prefix of `var` since in `contextual_isassumption`
140-
# it will be threaded through the `PrefixContext` before it reaches
141-
# `ConditionContext` with the conditioned variable.
142-
vn_without_prefix = if getoptic(vn) isa PropertyLens
143-
# Hacky: This assumes that there is exactly one level of prefixing
144-
# that we need to undo. This is appropriate for the :condition3
145-
# test case above, but is not generally correct.
146-
AbstractPPL.unprefix(vn, VarName{getsym(vn)}())
92+
conditioned_values = DynamicPPL.to_varname_dict(conditioned_values)
93+
94+
# Extract all conditioned variables. We also use varname_leaves
95+
# here to split up arrays which could potentially have some,
96+
# but not all, elements being `missing`.
97+
conditioned_vns = mapreduce(
98+
p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second),
99+
vcat,
100+
pairs(conditioned_values),
101+
)
102+
@show conditioned_vns
103+
104+
# We can now loop over them to check which ones are missing. We use
105+
# `getvalue` to handle the awkward case where sometimes
106+
# `conditioned_values` contains the full Varname (e.g. `a.x`) and
107+
# sometimes only the main symbol (e.g. it contains `x` when
108+
# `vn` is `x[1]`)
109+
for vn in conditioned_vns
110+
val = DynamicPPL.getvalue(conditioned_values, vn)
111+
# These VarNames are present in the conditioning values, so
112+
# we should always be able to extract the value.
113+
@test hasconditioned_nested(context, vn)
114+
@test getconditioned_nested(context, vn) === val
115+
# However, the return value of contextual_isassumption depends on
116+
# whether the value is missing or not.
117+
if ismissing(val)
118+
@test contextual_isassumption(context, vn)
147119
else
148-
vn
149-
end
150-
151-
for vn_child in
152-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
153-
# `vn_child` should be in `context`.
154-
@test hasconditioned_nested(context, vn_child)
155-
# Value should be the same as extracted above.
156-
@test getconditioned_nested(context, vn_child) ===
157-
getoptic(vn_child)(val)
120+
@test !contextual_isassumption(context, vn)
158121
end
159122
end
160123
end

0 commit comments

Comments
 (0)