Skip to content

Commit 32437fb

Browse files
committed
Fix some tests
1 parent f84988f commit 32437fb

File tree

3 files changed

+63
-90
lines changed

3 files changed

+63
-90
lines changed

src/contexts.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,11 @@ function conditioned(context::ConditionContext)
519519
# precedence over decendants of `context`.
520520
return _merge(context.values, conditioned(childcontext(context)))
521521
end
522+
function conditioned(context::PrefixContext{Prefix}) where {Prefix}
523+
return conditioned(
524+
prefix_conditioned_variables(childcontext(context), VarName{Prefix}())
525+
)
526+
end
522527

523528
struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
524529
values::Values

src/model.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,29 +425,34 @@ julia> # Returns all the variables we have conditioned on + their values.
425425
conditioned(condition(m, x=100.0, m=1.0))
426426
(x = 100.0, m = 1.0)
427427
428-
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
428+
julia> # Nested ones also work.
429+
# (Note that `PrefixContext` also prefixes the variables of any
430+
# ConditionContext that is _inside_ it; because of this, the type of the
431+
# container has to be broadened to a `Dict`.)
429432
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
430433
431434
julia> conditioned(cm)
432-
(x = 100.0, m = 1.0)
435+
Dict{VarName, Any} with 2 entries:
436+
a.m => 1.0
437+
x => 100.0
433438
434-
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
435-
# `a.m` is treated as a random variable.
439+
julia> # Since we conditioned on `a.m`, it is not treated as a random variable.
440+
# However, `a.x` will still be a random variable.
436441
keys(VarInfo(cm))
437-
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
438-
a.m
442+
1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}:
443+
a.x
439444
440-
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
441-
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0);
445+
julia> # We can also condition on `a.m` _outside_ of the PrefixContext:
446+
cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0));
442447
443-
julia> conditioned(cm)[@varname(x)]
444-
100.0
445-
446-
julia> conditioned(cm)[@varname(a.m)]
447-
1.0
448+
julia> conditioned(cm)
449+
Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry:
450+
a.m => 1.0
448451
449-
julia> keys(VarInfo(cm)) # No variables are sampled
450-
VarName[]
452+
julia> # Now `a.x` will be sampled.
453+
keys(VarInfo(cm))
454+
1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}:
455+
a.x
451456
```
452457
"""
453458
conditioned(model::Model) = conditioned(model.context)

test/contexts.jl

Lines changed: 38 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test, DynamicPPL, Accessors
2+
using AbstractPPL: getoptic
23
using DynamicPPL:
34
leafcontext,
45
setleafcontext,
@@ -57,7 +58,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
5758
(x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))
5859
),
5960
:condition3 => ConditionContext(
60-
(x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0)))
61+
(x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(y) => 2.0)))
6162
),
6263
:condition4 => ConditionContext((x=[1.0, missing],)),
6364
)
@@ -70,91 +71,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
7071
end
7172
end
7273

73-
@testset "contextual_isassumption" begin
74-
@testset "$(name)" for (name, context) in contexts
75-
# Any `context` should return `true` by default.
76-
@test contextual_isassumption(context, VarName{gensym(:x)}())
77-
78-
if any(Base.Fix2(isa, ConditionContext), context)
79-
# We have a `ConditionContext` among us.
80-
# Let's first extract the conditioned variables.
81-
conditioned_values = DynamicPPL.conditioned(context)
82-
83-
# The conditioned values might be a NamedTuple, or a Dict.
84-
# We convert to a Dict for consistency
85-
if conditioned_values isa NamedTuple
86-
conditioned_values = Dict(
87-
VarName{sym}() => val for (sym, val) in pairs(conditioned_values)
88-
)
89-
end
90-
91-
for (vn, val) in pairs(conditioned_values)
92-
# We need to drop the prefix of `var` since in `contextual_isassumption`
93-
# it will be threaded through the `PrefixContext` before it reaches
94-
# `ConditionContext` with the conditioned variable.
95-
vn_without_prefix = if getoptic(vn) isa PropertyLens
96-
# Hacky: This assumes that there is exactly one level of prefixing
97-
# that we need to undo. This is appropriate for the :condition3
98-
# test case above, but is not generally correct.
99-
AbstractPPL.unprefix(vn, VarName{getsym(vn)}())
100-
else
101-
vn
102-
end
103-
104-
@show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
105-
# Let's check elementwise.
106-
for vn_child in
107-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
108-
if getoptic(vn_child)(val) === missing
109-
@test contextual_isassumption(context, vn_child)
110-
else
111-
@test !contextual_isassumption(context, vn_child)
112-
end
113-
end
114-
end
115-
end
116-
end
117-
end
74+
@testset "extracting conditioned values" begin
75+
# This testset tests `contextual_isassumption`, `getconditioned_nested`, and
76+
# `hasconditioned_nested`.
11877

119-
@testset "getconditioned_nested & hasconditioned_nested" begin
120-
@testset "$name" for (name, context) in contexts
78+
@testset "$(name)" for (name, context) in contexts
79+
# If the varname doesn't exist, it should always be an assumption.
12180
fake_vn = VarName{gensym(:x)}()
81+
@test contextual_isassumption(context, fake_vn)
12282
@test !hasconditioned_nested(context, fake_vn)
12383
@test_throws ErrorException getconditioned_nested(context, fake_vn)
12484

12585
if any(Base.Fix2(isa, ConditionContext), context)
126-
# `ConditionContext` specific.
127-
86+
# We have a `ConditionContext` among us.
12887
# Let's first extract the conditioned variables.
12988
conditioned_values = DynamicPPL.conditioned(context)
89+
13090
# The conditioned values might be a NamedTuple, or a Dict.
13191
# We convert to a Dict for consistency
132-
if conditioned_values isa NamedTuple
133-
conditioned_values = Dict(
134-
VarName{sym}() => val for (sym, val) in pairs(conditioned_values)
135-
)
136-
end
137-
138-
for (vn, val) in pairs(conditioned_values)
139-
# We need to drop the prefix of `var` since in `contextual_isassumption`
140-
# it will be threaded through the `PrefixContext` before it reaches
141-
# `ConditionContext` with the conditioned variable.
142-
vn_without_prefix = if getoptic(vn) isa PropertyLens
143-
# Hacky: This assumes that there is exactly one level of prefixing
144-
# that we need to undo. This is appropriate for the :condition3
145-
# test case above, but is not generally correct.
146-
AbstractPPL.unprefix(vn, VarName{getsym(vn)}())
92+
conditioned_values = DynamicPPL.to_varname_dict(conditioned_values)
93+
94+
# Extract all conditioned variables. We also use varname_leaves
95+
# here to split up arrays which could potentially have some,
96+
# but not all, elements being `missing`.
97+
conditioned_vns = mapreduce(
98+
p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second),
99+
vcat,
100+
pairs(conditioned_values),
101+
)
102+
@show conditioned_vns
103+
104+
# We can now loop over them to check which ones are missing. We use
105+
# `getvalue` to handle the awkward case where sometimes
106+
# `conditioned_values` contains the full Varname (e.g. `a.x`) and
107+
# sometimes only the main symbol (e.g. it contains `x` when
108+
# `vn` is `x[1]`)
109+
for vn in conditioned_vns
110+
val = DynamicPPL.getvalue(conditioned_values, vn)
111+
# These VarNames are present in the conditioning values, so
112+
# we should always be able to extract the value.
113+
@test hasconditioned_nested(context, vn)
114+
@test getconditioned_nested(context, vn) === val
115+
# However, the return value of contextual_isassumption depends on
116+
# whether the value is missing or not.
117+
if ismissing(val)
118+
@test contextual_isassumption(context, vn)
147119
else
148-
vn
149-
end
150-
151-
for vn_child in
152-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
153-
# `vn_child` should be in `context`.
154-
@test hasconditioned_nested(context, vn_child)
155-
# Value should be the same as extracted above.
156-
@test getconditioned_nested(context, vn_child) ===
157-
getoptic(vn_child)(val)
120+
@test !contextual_isassumption(context, vn)
158121
end
159122
end
160123
end

0 commit comments

Comments
 (0)