Skip to content

Commit e6f1068

Browse files
committed
Avoid pirating AbstractPPL.prefix
1 parent 49a3123 commit e6f1068

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

src/compiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(
6666
# TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like
6767
# the whole `isassumption` thing to be simplified, though, so I'll
6868
# leave it till later.
69-
$vn = $(AbstractPPL.prefix)($left_vn, __model__.prefix)
69+
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
7070
if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn)
7171
# Considered an assumption by `__model__.context` which means either:
7272
# 1. We hit the default implementation, e.g. using `DefaultContext`,
@@ -454,7 +454,7 @@ function generate_tilde(left, right)
454454
return quote
455455
$dist = $right
456456
$left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
457-
$vn = $(AbstractPPL.prefix)($left_vn, __model__.prefix)
457+
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
458458
$isassumption = $(DynamicPPL.isassumption(left, left_vn))
459459
if $(DynamicPPL.isfixed(left, vn))
460460
$left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn)

src/prefix.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
"""
2-
AbstractPPL.prefix(vn::VarName, ::Nothing)
3-
AbstractPPL.prefix(::Nothing, vn::VarName)
2+
maybe_prefix(inner::Union{Nothing,<:VarName}, outer::Union{Nothing,<:VarName})
43
5-
Return the original `vn` (i.e., prefixed with nothing).
4+
Prefix `inner` with the prefix `outer`. Both `inner` and `outer` can be either
5+
`VarName`s or `Nothing`.
66
7-
These cases only happen in DynamicPPL and are thus handled here (AbstractPPL's
8-
definition works only if both arguments are `VarName`s).
7+
Note that this differs from `AbstractPPL.prefix` in that it handles `nothing` values.
8+
This can happen e.g. when prefixing a model that is not already prefixed; or when
9+
executing submodels without automatic prefixing.
910
"""
10-
AbstractPPL.prefix(vn::VarName, ::Nothing) = vn
11-
AbstractPPL.prefix(::Nothing, vn::VarName) = vn
11+
maybe_prefix(inner::VarName, outer::VarName) = AbstractPPL.prefix(inner, outer)
12+
maybe_prefix(vn::VarName, ::Nothing) = vn
13+
maybe_prefix(::Nothing, vn::VarName) = vn
14+
maybe_prefix(::Nothing, ::Nothing) = nothing
1215

1316
"""
1417
prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName)
@@ -94,11 +97,11 @@ julia> rand(prefix(demo(), Val(:my_prefix)))
9497
```
9598
"""
9699
prefix(model::Model, ::Nothing) = model
97-
function prefix(model::Model, x::VarName)
100+
function prefix(model::Model, vn::VarName)
98101
# Add it to the model prefix field
99-
new_prefix = AbstractPPL.prefix(model.prefix, x)
102+
new_prefix = maybe_prefix(model.prefix, vn)
100103
# And also make sure to prefix any conditioned and fixed variables stored in the model
101-
new_context = prefix_cond_and_fixed_variables(model.context, x)
104+
new_context = prefix_cond_and_fixed_variables(model.context, vn)
102105
return Model(model.f, model.args, model.defaults, new_context, new_prefix)
103106
end
104107
prefix(model::Model, ::Val{sym}) where {sym} = prefix(model, VarName{sym}())

src/submodel.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ function _evaluate!!(
164164
vi::AbstractVarInfo,
165165
parent_context::AbstractContext,
166166
parent_prefix::Union{Nothing,<:VarName},
167-
left_vn::VarName,
167+
vn::VarName,
168168
) where {M<:Model,AutoPrefix}
169169
# First, we construct the context to be used when evaluating the submodel. There
170170
# are several considerations here:
@@ -173,7 +173,9 @@ function _evaluate!!(
173173
# automatic prefixing if it was requested. (If the prefix was manually applied, then
174174
# `prefix()` will have been called by the user, and we don't need to do it again.)
175175
submodel_prefix = if AutoPrefix
176-
AbstractPPL.prefix(left_vn, parent_prefix)
176+
# Note that by the time we see it here (in `tilde_assume!!`), `vn`
177+
# has already prefixed with `parent_prefix`, so no need to re-prefix it
178+
vn
177179
else
178180
parent_prefix
179181
end

test/submodels.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,54 @@ using Test
133133
end
134134
end
135135

136-
@testset "Nested submodels" begin
136+
@testset "Nested submodels with auto prefix" begin
137+
@model function f()
138+
x ~ Normal()
139+
return y ~ Normal()
140+
end
141+
@model function g()
142+
return b ~ to_submodel(f())
143+
end
144+
@model function h()
145+
return a ~ to_submodel(g())
146+
end
147+
148+
# No conditioning
149+
vi = VarInfo(h())
150+
@test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
151+
@test getlogjoint(vi) ==
152+
logpdf(Normal(), vi[@varname(a.b.x)]) +
153+
logpdf(Normal(), vi[@varname(a.b.y)])
154+
155+
# Conditioning/fixing at the top level
156+
op_h = op(h(), (@varname(a.b.x) => x_val))
157+
158+
# Conditioning/fixing at the second level
159+
op_g = op(g(), (@varname(b.x) => x_val))
160+
@model function h2()
161+
return a ~ to_submodel(op_g)
162+
end
163+
164+
# Conditioning/fixing at the very bottom
165+
op_f = op(f(), (@varname(x) => x_val))
166+
@model function g2()
167+
return _unused ~ to_submodel(prefix(op_f, :b), false)
168+
end
169+
@model function h3()
170+
return a ~ to_submodel(g2())
171+
end
172+
173+
models = [("top", op_h), ("middle", h2()), ("bottom", h3())]
174+
@testset "$name" for (name, model) in models
175+
vi = VarInfo(model)
176+
@test Set(keys(vi)) == Set([@varname(a.b.y)])
177+
@test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
178+
end
179+
end
180+
181+
@testset "Nested submodels with manual prefix" begin
182+
# Same tests as above, just that the middle layer has manual prefixing
183+
# rather than automatic.
137184
@model function f()
138185
x ~ Normal()
139186
return y ~ Normal()

0 commit comments

Comments
 (0)