Skip to content

Commit cb3c91b

Browse files
committed
Fix conditioning in submodels
1 parent be27636 commit cb3c91b

File tree

4 files changed

+94
-20
lines changed

4 files changed

+94
-20
lines changed

src/compiler.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ function isassumption(
5353
vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))),
5454
)
5555
return quote
56-
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
56+
if $(DynamicPPL.contextual_isassumption)(
57+
__context__, $(DynamicPPL.prefix)(__context__, $vn)
58+
)
5759
# Considered an assumption by `__context__` which means either:
5860
# 1. We hit the default implementation, e.g. using `DefaultContext`,
5961
# which in turn means that we haven't considered if it's one of
@@ -112,8 +114,10 @@ function contextual_isassumption(context::ConditionContext, vn)
112114
# so we defer to `childcontext` if we haven't concluded that anything yet.
113115
return contextual_isassumption(childcontext(context), vn)
114116
end
115-
function contextual_isassumption(context::PrefixContext, vn)
116-
return contextual_isassumption(childcontext(context), prefix(context, vn))
117+
function contextual_isassumption(context::PrefixContext{Prefix}, vn) where {Prefix}
118+
return contextual_isassumption(
119+
prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn
120+
)
117121
end
118122

119123
isfixed(expr, vn) = false
@@ -473,7 +477,9 @@ function generate_tilde(left, right)
473477
else
474478
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
475479
if !$(DynamicPPL.inargnames)($vn, __model__)
476-
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
480+
$left = $(DynamicPPL.getconditioned_nested)(
481+
__context__, $(DynamicPPL.prefix)(__context__, $vn)
482+
)
477483
end
478484

479485
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(

src/context_implementations.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,34 @@ probability of `vi` with the returned value.
104104
"""
105105
function tilde_assume!!(context, right, vn, vi)
106106
return if is_rhs_model(right)
107-
# Prefix the variables using the `vn`.
108-
rand_like!!(
109-
right,
110-
should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context,
111-
vi,
112-
)
107+
# We want to merge the context inside `right`, as well as the context passed
108+
# down from the parent model. To ensure that _only_ variables inside
109+
# right.model.context are prefixed, we need to make sure to _only_ apply
110+
# PrefixContext to it, before wrapping the whole prefixed context in the
111+
# parent context.
112+
# NOTE: This relies on the existence of `right.model.model`. Right now,
113+
# the only thing that can return true for `is_rhs_model` is something
114+
# (a `Sampleable`) that has a `model` field that itself (a
115+
# `ReturnedModelWrapper`) has a `model` field. This may or may not
116+
# change in the future.
117+
if should_auto_prefix(right)
118+
dppl_model = right.model.model # This isa DynamicPPL.Model
119+
prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context)
120+
# Having done this, we can set the submodel context to its leaf context.
121+
new_dppl_model = contextualize(dppl_model, leafcontext(dppl_model.context))
122+
# Then reconstruct the Sampleable
123+
right = to_submodel(new_dppl_model, true)
124+
# We can splice that prefixed submodel context into the parent context.
125+
original_parent_leafcontext = leafcontext(context)
126+
new_parent_context = setleafcontext(context, prefixed_submodel_context)
127+
new_parent_context = setleafcontext(
128+
new_parent_context, original_parent_leafcontext
129+
)
130+
else
131+
# If we don't need to prefix, we can just use the parent context as-is
132+
new_parent_context = context
133+
end
134+
rand_like!!(right, new_parent_context, vi)
113135
else
114136
value, logp, vi = tilde_assume(context, right, vn, vi)
115137
value, acclogp_assume!!(context, vi, logp)

src/contexts.jl

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ end
265265
266266
Apply the prefixes in the context `ctx` to the variable name `vn`.
267267
"""
268-
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
269-
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}())
268+
function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix}
269+
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}())
270270
end
271271
function prefix(ctx::AbstractContext, vn::VarName)
272272
return prefix(NodeTrait(ctx), ctx, vn)
@@ -351,6 +351,43 @@ NodeTrait(::ConditionContext) = IsParent()
351351
childcontext(context::ConditionContext) = context.context
352352
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)
353353

354+
"""
355+
prefix_conditioned_variables(context::AbstractContext, prefix::VarName)
356+
357+
Prefix all the conditioned variables in a given context with `prefix`.
358+
359+
```jldoctest
360+
julia> using DynamicPPL: prefix_conditioned_variables, ConditionContext
361+
362+
julia> c1 = ConditionContext((a=1, ))
363+
ConditionContext((a = 1,), DefaultContext())
364+
365+
julia> prefix_conditioned_variables(c1, @varname(y))
366+
ConditionContext(Dict(y.a => 1), DefaultContext())
367+
```
368+
"""
369+
function prefix_conditioned_variables(ctx::ConditionContext, prefix::VarName)
370+
# Replace the prefix of the conditioned variables
371+
vn_dict = to_varname_dict(ctx.values)
372+
prefixed_vn_dict = Dict(
373+
AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict
374+
)
375+
# Prefix the child context as well
376+
prefixed_child_ctx = prefix_conditioned_variables(childcontext(ctx), prefix)
377+
return ConditionContext(prefixed_vn_dict, prefixed_child_ctx)
378+
end
379+
function prefix_conditioned_variables(c::AbstractContext, prefix::VarName)
380+
return prefix_conditioned_variables(
381+
NodeTrait(prefix_conditioned_variables, c), c, prefix
382+
)
383+
end
384+
prefix_conditioned_variables(::IsLeaf, context::AbstractContext, prefix::VarName) = context
385+
function prefix_conditioned_variables(::IsParent, context::AbstractContext, prefix::VarName)
386+
return setchildcontext(
387+
context, prefix_conditioned_variables(childcontext(context), prefix)
388+
)
389+
end
390+
354391
"""
355392
hasconditioned(context::AbstractContext, vn::VarName)
356393
@@ -370,7 +407,9 @@ Return value of `vn` in `context`.
370407
function getconditioned(context::AbstractContext, vn::VarName)
371408
return error("context $(context) does not contain value for $vn")
372409
end
373-
getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn)
410+
function getconditioned(context::ConditionContext, vn::VarName)
411+
return getvalue(context.values, vn)
412+
end
374413

375414
"""
376415
hasconditioned_nested(context, vn)
@@ -387,8 +426,10 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
387426
function hasconditioned_nested(::IsParent, context, vn)
388427
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
389428
end
390-
function hasconditioned_nested(context::PrefixContext, vn)
391-
return hasconditioned_nested(childcontext(context), prefix(context, vn))
429+
function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
430+
return hasconditioned_nested(
431+
prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn
432+
)
392433
end
393434

394435
"""
@@ -405,8 +446,10 @@ end
405446
function getconditioned_nested(::IsLeaf, context, vn)
406447
return error("context $(context) does not contain value for $vn")
407448
end
408-
function getconditioned_nested(context::PrefixContext, vn)
409-
return getconditioned_nested(childcontext(context), prefix(context, vn))
449+
function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix}
450+
return getconditioned_nested(
451+
prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn
452+
)
410453
end
411454
function getconditioned_nested(::IsParent, context, vn)
412455
return if hasconditioned(context, vn)

src/utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,17 +1286,20 @@ broadcast_safe(x::Distribution) = (x,)
12861286
broadcast_safe(x::AbstractContext) = (x,)
12871287

12881288
# Convert (x=1,) to Dict(@varname(x) => 1)
1289-
_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt))
1289+
function to_varname_dict(nt::NamedTuple)
1290+
return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt))
1291+
end
1292+
to_varname_dict(d::AbstractDict) = d
12901293
# Version of `merge` used by `conditioned` and `fixed` to handle
12911294
# the scenario where we might try to merge a dict with an empty
12921295
# tuple.
12931296
# TODO: Maybe replace the default of returning `NamedTuple` with `nothing`?
12941297
_merge(left::NamedTuple, right::NamedTuple) = merge(left, right)
12951298
_merge(left::AbstractDict, right::AbstractDict) = merge(left, right)
12961299
_merge(left::AbstractDict, ::NamedTuple{()}) = left
1297-
_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right))
1300+
_merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(right))
12981301
_merge(::NamedTuple{()}, right::AbstractDict) = right
1299-
_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right)
1302+
_merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right)
13001303

13011304
"""
13021305
unique_syms(vns::T) where {T<:NTuple{N,VarName}}

0 commit comments

Comments
 (0)