Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.0"
version = "0.33.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

```@docs
|(::Model, ::Any)
|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
condition
DynamicPPL.conditioned
```
Expand Down Expand Up @@ -403,6 +403,7 @@ LikelihoodContext
PriorContext
MiniBatchContext
PrefixContext
ConditionContext
```

### Samplers
Expand Down
120 changes: 44 additions & 76 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
"""

ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}

Model context that contains values that are to be conditioned on. The values
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
@varname(b) => 2)`). The former is more performant, but the latter must be used
when there are varnames that cannot be represented as symbols, e.g.
`@varname(x[1])`.
"""
struct ConditionContext{
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
} <: AbstractContext
values::Values
context::Ctx
end

const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
const DictConditionContext = ConditionContext{<:AbstractDict}

ConditionContext(values) = ConditionContext(values, DefaultContext())

# Try to avoid nested `ConditionContext`.
# Use DefaultContext as the default base context
function ConditionContext(values::Union{NamedTuple,AbstractDict})
return ConditionContext(values, DefaultContext())
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
end
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
return ConditionContext(merge(context.values, values), childcontext(context))
end

Expand Down Expand Up @@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
end
end

"""
condition([context::AbstractContext,] values::NamedTuple)
condition([context::AbstractContext]; values...)

Return `ConditionContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.

See also: [`decondition`](@ref)
"""
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return condition((value, values...))
end
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
return condition(DefaultContext(), values)
end
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end
function AbstractPPL.condition(
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
)
return condition(context, (value, values...))
end
function AbstractPPL.condition(
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
)
return condition(context, Dict(values))
end

"""
decondition(context::AbstractContext, syms...)

Expand All @@ -445,42 +428,27 @@ Note that this recursively traverses contexts, deconditioning all along the way.

See also: [`condition`](@ref)
"""
AbstractPPL.decondition(::IsLeaf, context, args...) = context
function AbstractPPL.decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
end
function AbstractPPL.decondition(context, args...)
return decondition(NodeTrait(context), context, args...)
end
function AbstractPPL.decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
decondition_context(::IsLeaf, context, args...) = context
function decondition_context(::IsParent, context, args...)
return setchildcontext(context, decondition_context(childcontext(context), args...))
end
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
BangBang.delete!!(context.values, sym),
),
syms...,
)
function decondition_context(context, args...)
return decondition_context(NodeTrait(context), context, args...)
end

function AbstractPPL.decondition(
context::NamedConditionContext, vn::VarName{sym}
) where {sym}
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext)
return decondition_context(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
)
function decondition_context(context::ConditionContext, sym, syms...)
new_values = deepcopy(context.values)
for s in (sym, syms...)
new_values = BangBang.delete!!(new_values, s)
end
return if length(new_values) == 0
# No more values left, can unwrap
decondition_context(childcontext(context), syms...)
else
ConditionContext(new_values, decondition_context(childcontext(context), syms...))
end
end

"""
Expand Down
30 changes: 26 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio

See [`condition`](@ref) for more information and examples.
"""
Base.:|(model::Model, values) = condition(model, values)
Base.:|(model::Model, values::Union{Tuple,NamedTuple,AbstractDict{<:VarName}}) =
condition(model, values)

"""
condition(model::Model; values...)
Expand Down Expand Up @@ -264,11 +265,32 @@ julia> conditioned_model_dict()
1.0
```
"""
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
function AbstractPPL.condition(model::Model, value, values...)
return contextualize(model, condition(model.context, value, values...))
function AbstractPPL.condition(model::Model, values...)
# Positional arguments - need to handle cases carefully
return contextualize(
model, ConditionContext(_make_conditioning_values(values...), model.context)
)
end
function AbstractPPL.condition(model::Model; values...)
# Keyword arguments -- just convert to a NamedTuple
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
end

"""
_make_conditioning_values(vals...)

Convert different types of input to either a `NamedTuple` or `AbstractDict` of
conditioning values, suitable for storage in a `ConditionContext`.
"""
# Case 1: Already in the right format, e.g. condition(model, (x=1, y=2))
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
# Case 2: condition(model, (@varname(x) => 1, @varname(y) => 2))
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
# Case 3: Case 1 but splatted, e.g. condition(model, x=1, y=2)
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
# Case 4: Case 2 but splatted, e.g. condition(model, @varname(x) => 1, @varname(y) => 2)
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)

"""
decondition(model::Model)
decondition(model::Model, variables...)
Expand Down
66 changes: 66 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DynamicPPL:
PointwiseLogdensityContext,
contextual_isassumption,
ConditionContext,
decondition_context,
hasconditioned,
getconditioned,
hasconditioned_nested,
Expand Down Expand Up @@ -196,6 +197,71 @@ end
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "ConditionContext" begin
@testset "Nesting" begin
@testset "NamedTuple" begin
n1 = (x=1, y=2)
n2 = (x=3,)
# Values from outer context should override inner one
ctx1 = ConditionContext(n1, ConditionContext(n2))
@test ctx1.values == (x=1, y=2)
# Check that the two ConditionContexts are collapsed
@test childcontext(ctx1) isa DefaultContext
# Then test the nesting the other way round
ctx2 = ConditionContext(n2, ConditionContext(n1))
@test ctx2.values == (x=3, y=2)
@test childcontext(ctx2) isa DefaultContext
end

@testset "Dict" begin
# Same tests as NamedTuple above
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
d2 = Dict(@varname(x) => 3)
ctx1 = ConditionContext(d1, ConditionContext(d2))
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
@test childcontext(ctx1) isa DefaultContext
ctx2 = ConditionContext(d2, ConditionContext(d1))
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
@test childcontext(ctx2) isa DefaultContext
end
end

@testset "decondition_context" begin
@testset "NamedTuple" begin
ctx = ConditionContext((x=1, y=2, z=3))
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, :x)
@test dctx isa ConditionContext
@test dctx.values == (y=2, z=3)
dctx = decondition_context(ctx, :y, :z)
@test dctx isa ConditionContext
@test dctx.values == (x=1,)
# Decondition all variables manually
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
end

@testset "Dict" begin
ctx = ConditionContext(
Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3)
)
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, @varname(x))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
dctx = decondition_context(ctx, @varname(y), @varname(z))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(x) => 1)
# Decondition all variables manually
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
DefaultContext
end
end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
retval = model()
Expand Down
22 changes: 22 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "model conditioning with various arguments" begin
@model function demo_condition()
x ~ Normal()
return y ~ Normal(x)
end
model = demo_condition()
# Test that different syntaxes work and give the same underlying ConditionContext
@testset "NamedTuple ConditionContext" begin
expected_values = (y=2,)
@test condition(model, (y=2,)).context.values == expected_values
@test condition(model; y=2).context.values == expected_values
@test condition(model; y=2).context.values == expected_values
@test (model | (y=2,)).context.values == expected_values
end
@testset "AbstractDict ConditionContext" begin
expected_values = Dict(@varname(y) => 2)
@test condition(model, Dict(@varname(y) => 2)).context.values == expected_values
@test condition(model, @varname(y) => 2).context.values == expected_values
@test (model | (@varname(y) => 2,)).context.values == expected_values
end
end

@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
@model function multiple_types(x)
ns ~ filldist(Normal(0, 2.0), 3)
Expand Down
Loading