Skip to content

Commit ecc3c02

Browse files
torfjeldeyebai
andcommitted
condition and decondition (#301)
This PR introduces `condition` and `decondition`. This is really just a reopening of #294 that I can't reopen directly due to the target branch now being deleted. Co-authored-by: Hong Ge <[email protected]>
1 parent ea658b5 commit ecc3c02

File tree

9 files changed

+708
-26
lines changed

9 files changed

+708
-26
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.13.2"
3+
version = "0.14.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
8+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
89
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516
[compat]
1617
AbstractMCMC = "2, 3.0"
1718
AbstractPPL = "0.2"
19+
BangBang = "0.3"
1820
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
1921
ChainRulesCore = "0.9.7, 0.10"
2022
Distributions = "0.23.8, 0.24, 0.25"

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC
99
using ChainRulesCore: ChainRulesCore
1010
using MacroTools: MacroTools
1111
using ZygoteRules: ZygoteRules
12+
using BangBang: BangBang
1213

1314
using Random: Random
1415

@@ -81,6 +82,7 @@ export AbstractVarInfo,
8182
PriorContext,
8283
MiniBatchContext,
8384
PrefixContext,
85+
ConditionContext,
8486
assume,
8587
dot_assume,
8688
observe,
@@ -99,6 +101,8 @@ export AbstractVarInfo,
99101
logprior,
100102
logjoint,
101103
pointwise_loglikelihoods,
104+
condition,
105+
decondition,
102106
# Convenience macros
103107
@addlogprob!,
104108
@submodel

src/compiler.jl

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,66 @@ function isassumption(expr::Union{Symbol,Expr})
2020

2121
return quote
2222
let $vn = $(varname(expr))
23-
# This branch should compile nicely in all cases except for partial missing data
24-
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
25-
if !$(DynamicPPL.inargnames)($vn, __model__) ||
26-
$(DynamicPPL.inmissings)($vn, __model__)
27-
true
23+
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
24+
# Considered an assumption by `__context__` which means either:
25+
# 1. We hit the default implementation, e.g. using `DefaultContext`,
26+
# which in turn means that we haven't considered if it's one of
27+
# the model arguments, hence we need to check this.
28+
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
29+
# i.e. we're trying to condition one of the latent variables.
30+
# In this case, the below will return `true` since the first branch
31+
# will be hit.
32+
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
33+
# i.e. we're trying to override the value. This is currently NOT supported.
34+
# TODO: Support by adding context to model, and use `model.args`
35+
# as the default conditioning. Then we no longer need to check `inargnames`
36+
# since it will all be handled by `contextual_isassumption`.
37+
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
38+
$(DynamicPPL.inmissings)($vn, __model__)
39+
true
40+
else
41+
$(maybe_view(expr)) === missing
42+
end
2843
else
29-
# Evaluate the LHS
30-
$(maybe_view(expr)) === missing
44+
false
3145
end
3246
end
3347
end
3448
end
3549

50+
"""
51+
contextual_isassumption(context, vn)
52+
53+
Return `true` if `vn` is considered an assumption by `context`.
54+
55+
The default implementation for `AbstractContext` always returns `true`.
56+
"""
57+
contextual_isassumption(::IsLeaf, context, vn) = true
58+
function contextual_isassumption(::IsParent, context, vn)
59+
return contextual_isassumption(childcontext(context), vn)
60+
end
61+
function contextual_isassumption(context::AbstractContext, vn)
62+
return contextual_isassumption(NodeTrait(context), context, vn)
63+
end
64+
function contextual_isassumption(context::ConditionContext, vn)
65+
if hasvalue(context, vn)
66+
val = getvalue(context, vn)
67+
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
68+
if eltype(val) >: Missing && val === missing
69+
return true
70+
else
71+
return false
72+
end
73+
end
74+
75+
# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`
76+
# so we defer to `childcontext` if we haven't concluded that anything yet.
77+
return contextual_isassumption(childcontext(context), vn)
78+
end
79+
function contextual_isassumption(context::PrefixContext, vn)
80+
return contextual_isassumption(childcontext(context), prefix(context, vn))
81+
end
82+
3683
# failsafe: a literal is never an assumption
3784
isassumption(expr) = :(false)
3885

@@ -93,7 +140,7 @@ variables.
93140
94141
# Example
95142
```jldoctest; setup=:(using Distributions)
96-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); string(vns[end])
143+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal([1.0, 1.0], [1.0 0.0; 0.0 1.0]), randn(2, 2), @varname(x)); string(vns[end])
97144
"x[:,2]"
98145
99146
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end])
@@ -351,6 +398,11 @@ function generate_tilde(left, right)
351398
__varinfo__,
352399
)
353400
else
401+
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
402+
if !$(DynamicPPL.inargnames)($vn, __model__)
403+
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
404+
end
405+
354406
$(DynamicPPL.tilde_observe!)(
355407
__context__,
356408
$(DynamicPPL.check_tilde_rhs)($right),
@@ -395,6 +447,11 @@ function generate_dot_tilde(left, right)
395447
__varinfo__,
396448
)
397449
else
450+
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
451+
if !$(DynamicPPL.inargnames)($vn, __model__)
452+
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
453+
end
454+
398455
$(DynamicPPL.dot_tilde_observe!)(
399456
__context__,
400457
$(DynamicPPL.check_tilde_rhs)($right),

src/context_implementations.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
1414
require_gradient(spl::Sampler) = false
1515
require_particles(spl::Sampler) = false
1616

17-
_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
17+
_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds))
1818
_getindex(x, inds::Tuple{}) = x
19+
_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing)
20+
function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym}
21+
val = getproperty(x, sym)
22+
23+
# This should work with both cartesian and linear indexing.
24+
return map(vns) do vn
25+
_getindex(val, vn)
26+
end
27+
end
1928

2029
# assume
2130
"""
@@ -177,13 +186,14 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0
177186
function tilde_observe(context::MiniBatchContext, right, left, vi)
178187
return context.loglike_scalar * tilde_observe(context.context, right, left, vi)
179188
end
180-
function tilde_observe(context::MiniBatchContext, right, left, vname, vi)
181-
return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi)
189+
function tilde_observe(context::MiniBatchContext, sampler, right, left, vi)
190+
return context.loglike_scalar *
191+
tilde_observe(context.context, sampler, right, left, vname, vi)
182192
end
183193

184194
# `PrefixContext`
185-
function tilde_observe(context::PrefixContext, right, left, vname, vi)
186-
return tilde_observe(context.context, right, left, prefix(context, vname), vi)
195+
function tilde_observe(context::PrefixContext, right, left, vi)
196+
return tilde_observe(context.context, right, left, vi)
187197
end
188198

189199
"""

src/contexts.jl

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,176 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
251251
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
252252
end
253253
end
254+
255+
struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext
256+
values::Values
257+
context::Ctx
258+
259+
function ConditionContext{Values}(
260+
values::Values, context::AbstractContext
261+
) where {names,Values<:NamedTuple{names}}
262+
return new{names,typeof(values),typeof(context)}(values, context)
263+
end
264+
end
265+
266+
function ConditionContext(values::NamedTuple)
267+
return ConditionContext(values, DefaultContext())
268+
end
269+
function ConditionContext(values::NamedTuple, context::AbstractContext)
270+
return ConditionContext{typeof(values)}(values, context)
271+
end
272+
273+
# Try to avoid nested `ConditionContext`.
274+
function ConditionContext(
275+
values::NamedTuple{Names}, context::ConditionContext
276+
) where {Names}
277+
# Note that this potentially overrides values from `context`, thus giving
278+
# precedence to the outmost `ConditionContext`.
279+
return ConditionContext(merge(context.values, values), childcontext(context))
280+
end
281+
282+
function Base.show(io::IO, context::ConditionContext)
283+
return print(io, "ConditionContext($(context.values), $(childcontext(context)))")
284+
end
285+
286+
NodeTrait(context::ConditionContext) = IsParent()
287+
childcontext(context::ConditionContext) = context.context
288+
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)
289+
290+
"""
291+
hasvalue(context, vn)
292+
293+
Return `true` if `vn` is found in `context`.
294+
"""
295+
hasvalue(context, vn) = false
296+
297+
function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
298+
return sym in vars
299+
end
300+
function hasvalue(
301+
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
302+
) where {vars,sym}
303+
return sym in vars
304+
end
305+
306+
"""
307+
getvalue(context, vn)
308+
309+
Return value of `vn` in `context`.
310+
"""
311+
function getvalue(context::AbstractContext, vn)
312+
return error("context $(context) does not contain value for $vn")
313+
end
314+
getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn)
315+
316+
"""
317+
hasvalue_nested(context, vn)
318+
319+
Return `true` if `vn` is found in `context` or any of its descendants.
320+
321+
This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`,
322+
not recursively checking if `vn` is in any of its descendants.
323+
"""
324+
function hasvalue_nested(context::AbstractContext, vn)
325+
return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn)
326+
end
327+
hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn)
328+
function hasvalue_nested(::IsParent, context, vn)
329+
return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn)
330+
end
331+
function hasvalue_nested(context::PrefixContext, vn)
332+
return hasvalue_nested(childcontext(context), prefix(context, vn))
333+
end
334+
335+
"""
336+
getvalue_nested(context, vn)
337+
338+
Return the value of the parameter corresponding to `vn` from `context` or its descendants.
339+
340+
This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`,
341+
not recursively looking into its descendants.
342+
"""
343+
function getvalue_nested(context::AbstractContext, vn)
344+
return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn)
345+
end
346+
function getvalue_nested(::IsLeaf, context, vn)
347+
return error("context $(context) does not contain value for $vn")
348+
end
349+
function getvalue_nested(context::PrefixContext, vn)
350+
return getvalue_nested(childcontext(context), prefix(context, vn))
351+
end
352+
function getvalue_nested(::IsParent, context, vn)
353+
return if hasvalue(context, vn)
354+
getvalue(context, vn)
355+
else
356+
getvalue_nested(childcontext(context), vn)
357+
end
358+
end
359+
360+
"""
361+
condition([context::AbstractContext,] values::NamedTuple)
362+
condition([context::AbstractContext]; values...)
363+
364+
Return `ConditionContext` with `values` and `context` if `values` is non-empty,
365+
otherwise return `context` which is [`DefaultContext`](@ref) by default.
366+
367+
See also: [`decondition`](@ref)
368+
"""
369+
condition(; values...) = condition(DefaultContext(), NamedTuple(values))
370+
condition(values::NamedTuple) = condition(DefaultContext(), values)
371+
condition(context::AbstractContext, values::NamedTuple{()}) = context
372+
condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context)
373+
condition(context::AbstractContext; values...) = condition(context, NamedTuple(values))
374+
375+
"""
376+
decondition(context::AbstractContext, syms...)
377+
378+
Return `context` but with `syms` no longer conditioned on.
379+
380+
Note that this recursively traverses contexts, deconditioning all along the way.
381+
382+
See also: [`condition`](@ref)
383+
"""
384+
decondition(::IsLeaf, context, args...) = context
385+
function decondition(::IsParent, context, args...)
386+
return setchildcontext(context, decondition(childcontext(context), args...))
387+
end
388+
decondition(context, args...) = decondition(NodeTrait(context), context, args...)
389+
function decondition(context::ConditionContext)
390+
return decondition(childcontext(context))
391+
end
392+
function decondition(context::ConditionContext, sym)
393+
return condition(
394+
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
395+
)
396+
end
397+
function decondition(context::ConditionContext, sym, syms...)
398+
return decondition(
399+
condition(
400+
decondition(childcontext(context), syms...),
401+
BangBang.delete!!(context.values, sym),
402+
),
403+
syms...,
404+
)
405+
end
406+
407+
"""
408+
conditioned(context::AbstractContext)
409+
410+
Return `NamedTuple` of values that are conditioned on under context`.
411+
412+
Note that this will recursively traverse the context stack and return
413+
a merged version of the condition values.
414+
"""
415+
function conditioned(context::AbstractContext)
416+
return conditioned(NodeTrait(conditioned, context), context)
417+
end
418+
conditioned(::IsLeaf, context) = ()
419+
conditioned(::IsParent, context) = conditioned(childcontext(context))
420+
function conditioned(context::ConditionContext)
421+
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
422+
# is that the outermost `context` takes precendence, hence when resolving
423+
# the `conditioned` variables we need to ensure that `context.values` takes
424+
# precedence over decendants of `context`.
425+
return merge(context.values, conditioned(childcontext(context)))
426+
end

0 commit comments

Comments
 (0)