Skip to content

Commit cf87ce7

Browse files
committed
Move PrefixContext to a model field
1 parent 331279c commit cf87ce7

File tree

12 files changed

+246
-563
lines changed

12 files changed

+246
-563
lines changed

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
464464

465465
```@docs
466466
DefaultContext
467-
PrefixContext
468467
ConditionContext
469468
InitContext
470469
```

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ export AbstractVarInfo,
102102
# Contexts
103103
contextualize,
104104
DefaultContext,
105-
PrefixContext,
106105
ConditionContext,
107106
# Tilde pipeline
108107
tilde_assume!!,
@@ -177,6 +176,7 @@ include("chains.jl")
177176
include("contexts.jl")
178177
include("contexts/init.jl")
179178
include("model.jl")
179+
include("prefix.jl")
180180
include("sampler.jl")
181181
include("varname.jl")
182182
include("distribution_wrappers.jl")

src/compiler.jl

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@ evaluates to a `VarName`, and this will be used in the subsequent checks.
6060
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
6161
used in its place.
6262
"""
63-
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
63+
function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr))
64+
@gensym vn
6465
return quote
65-
if $(DynamicPPL.contextual_isassumption)(
66-
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
67-
)
66+
# TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like
67+
# the whole `isassumption` thing to be simplified, though, so I'll
68+
# leave it till later.
69+
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
70+
if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn)
6871
# Considered an assumption by `__model__.context` which means either:
6972
# 1. We hit the default implementation, e.g. using `DefaultContext`,
7073
# which in turn means that we haven't considered if it's one of
@@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)
7881
# TODO: Support by adding context to model, and use `model.args`
7982
# as the default conditioning. Then we no longer need to check `inargnames`
8083
# since it will all be handled by `contextual_isassumption`.
81-
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
82-
$(DynamicPPL.inmissings)($vn, __model__)
84+
if !($(DynamicPPL.inargnames)($left_vn, __model__)) ||
85+
$(DynamicPPL.inmissings)($left_vn, __model__)
8386
true
8487
else
8588
$(maybe_view(expr)) === missing
@@ -99,7 +102,7 @@ isassumption(expr) = :(false)
99102
100103
Return `true` if `vn` is considered an assumption by `context`.
101104
"""
102-
function contextual_isassumption(context::AbstractContext, vn)
105+
function contextual_isassumption(context::AbstractContext, vn::VarName)
103106
if hasconditioned_nested(context, vn)
104107
val = getconditioned_nested(context, vn)
105108
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
@@ -115,9 +118,7 @@ end
115118

116119
isfixed(expr, vn) = false
117120
function isfixed(::Union{Symbol,Expr}, vn)
118-
return :($(DynamicPPL.contextual_isfixed)(
119-
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
120-
))
121+
return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn))
121122
end
122123

123124
"""
@@ -413,7 +414,9 @@ function generate_assign(left, right)
413414
return quote
414415
$right_val = $right
415416
if $(DynamicPPL.is_extracting_values)(__varinfo__)
416-
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
417+
$vn = $(DynamicPPL.maybe_prefix)(
418+
$(make_varname_expression(left)), __model__.prefix
419+
)
417420
__varinfo__ = $(map_accumulator!!)(
418421
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
419422
)
@@ -448,24 +451,23 @@ function generate_tilde(left, right)
448451

449452
# Otherwise it is determined by the model or its value,
450453
# if the LHS represents an observation
451-
@gensym vn isassumption value dist
454+
@gensym left_vn vn isassumption value dist
452455

453456
return quote
454457
$dist = $right
455-
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
456-
$isassumption = $(DynamicPPL.isassumption(left, vn))
458+
$left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
459+
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
460+
$isassumption = $(DynamicPPL.isassumption(left, left_vn))
457461
if $(DynamicPPL.isfixed(left, vn))
458-
$left = $(DynamicPPL.getfixed_nested)(
459-
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
460-
)
462+
$left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn)
461463
elseif $isassumption
462464
$(generate_tilde_assume(left, dist, vn))
463465
else
464-
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465-
if !$(DynamicPPL.inargnames)($vn, __model__)
466-
$left = $(DynamicPPL.getconditioned_nested)(
467-
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468-
)
466+
# If `left_vn` is not in `argnames`, we need to make sure that the variable is defined.
467+
# (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had
468+
# prefixes applied!)
469+
if !$(DynamicPPL.inargnames)($left_vn, __model__)
470+
$left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn)
469471
end
470472

471473
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
@@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn)
495497
return quote
496498
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
497499
__model__.context,
500+
__model__.prefix,
498501
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
499502
__varinfo__,
500503
)

src/context_implementations.jl

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,30 @@
11
# assume
2-
function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi)
3-
return tilde_assume!!(childcontext(context), right, vn, vi)
2+
function tilde_assume!!(context::AbstractContext, prefix, right::Distribution, vn, vi)
3+
return tilde_assume!!(childcontext(context), prefix, right, vn, vi)
44
end
5-
function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi)
5+
function tilde_assume!!(::DefaultContext, prefix, right::Distribution, vn, vi)
66
y = getindex_internal(vi, vn)
77
f = from_maybe_linked_internal_transform(vi, vn, right)
88
x, inv_logjac = with_logabsdet_jacobian(f, y)
99
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
1010
return x, vi
1111
end
12-
function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
13-
# Note that we can't use something like this here:
14-
# new_vn = prefix(context, vn)
15-
# return tilde_assume!!(childcontext(context), right, new_vn, vi)
16-
# This is because `prefix` applies _all_ prefixes in a given context to a
17-
# variable name. Thus, if we had two levels of nested prefixes e.g.
18-
# `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the
19-
# first call would apply the prefix `a.b._`, and the recursive call
20-
# would apply the prefix `b._`, resulting in `b.a.b._`.
21-
# This is why we need a special function, `prefix_and_strip_contexts`.
22-
new_vn, new_context = prefix_and_strip_contexts(context, vn)
23-
return tilde_assume!!(new_context, right, new_vn, vi)
24-
end
2512

2613
"""
27-
tilde_assume!!(context, right, vn, vi)
14+
tilde_assume!!(context, prefix, right, vn, vi)
2815
2916
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
3017
accumulate the log probability, and return the sampled value and updated `vi`.
3118
"""
32-
function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)
33-
return _evaluate!!(right, vi, context, vn)
19+
function tilde_assume!!(context, prefix, right::DynamicPPL.Submodel, vn, vi)
20+
return _evaluate!!(right, vi, context, prefix, vn)
3421
end
3522

3623
# observe
3724
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
3825
return tilde_observe!!(childcontext(context), right, left, vn, vi)
3926
end
4027

41-
# `PrefixContext`
42-
function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
43-
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
44-
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the
45-
# comment in `tilde_assume!!`.
46-
new_vn, new_context = if vn !== nothing
47-
prefix_and_strip_contexts(context, vn)
48-
else
49-
vn, childcontext(context)
50-
end
51-
return tilde_observe!!(new_context, right, left, new_vn, vi)
52-
end
53-
5428
"""
5529
tilde_observe!!(context, right, left, vn, vi)
5630

0 commit comments

Comments
 (0)