Skip to content

Commit 9a8a36b

Browse files
committed
Remove condition type piracy
1 parent 003ff2f commit 9a8a36b

File tree

2 files changed

+64
-64
lines changed

2 files changed

+64
-64
lines changed

src/contexts.jl

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,31 @@ function prefix(model::Model, ::Val{x}) where {x}
309309
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
310310
end
311311

312-
struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
312+
"""
313+
314+
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
315+
316+
Model context that contains values that are to be conditioned on. The values
317+
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
318+
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
319+
@varname(b) => 2)`). The former is more performant, but the latter must be used
320+
when there are varnames that cannot be represented as symbols, e.g.
321+
`@varname(x[1])`.
322+
"""
323+
struct ConditionContext{
324+
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
325+
} <: AbstractContext
313326
values::Values
314327
context::Ctx
315328
end
316329

317330
const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
318331
const DictConditionContext = ConditionContext{<:AbstractDict}
319332

320-
ConditionContext(values) = ConditionContext(values, DefaultContext())
321-
333+
# Use DefaultContext as the default base context
334+
ConditionContext(values::Union{NamedTuple,AbstractDict}) = ConditionContext(values, DefaultContext())
335+
# Optimisation when there are no values to condition on
336+
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
322337
# Try to avoid nested `ConditionContext`.
323338
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
324339
# Note that this potentially overrides values from `context`, thus giving
@@ -399,43 +414,6 @@ function getconditioned_nested(::IsParent, context, vn)
399414
end
400415
end
401416

402-
"""
403-
condition([context::AbstractContext,] values::NamedTuple)
404-
condition([context::AbstractContext]; values...)
405-
406-
Return `ConditionContext` with `values` and `context` if `values` is non-empty,
407-
otherwise return `context` which is [`DefaultContext`](@ref) by default.
408-
409-
See also: [`decondition`](@ref)
410-
"""
411-
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
412-
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
413-
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
414-
return condition((value, values...))
415-
end
416-
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
417-
return condition(DefaultContext(), values)
418-
end
419-
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
420-
function AbstractPPL.condition(
421-
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
422-
)
423-
return ConditionContext(values, context)
424-
end
425-
function AbstractPPL.condition(context::AbstractContext; values...)
426-
return condition(context, NamedTuple(values))
427-
end
428-
function AbstractPPL.condition(
429-
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
430-
)
431-
return condition(context, (value, values...))
432-
end
433-
function AbstractPPL.condition(
434-
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
435-
)
436-
return condition(context, Dict(values))
437-
end
438-
439417
"""
440418
decondition(context::AbstractContext, syms...)
441419
@@ -445,41 +423,41 @@ Note that this recursively traverses contexts, deconditioning all along the way.
445423
446424
See also: [`condition`](@ref)
447425
"""
448-
AbstractPPL.decondition(::IsLeaf, context, args...) = context
449-
function AbstractPPL.decondition(::IsParent, context, args...)
450-
return setchildcontext(context, decondition(childcontext(context), args...))
426+
decondition_context(::IsLeaf, context, args...) = context
427+
function decondition_context(::IsParent, context, args...)
428+
return setchildcontext(context, decondition_context(childcontext(context), args...))
451429
end
452-
function AbstractPPL.decondition(context, args...)
453-
return decondition(NodeTrait(context), context, args...)
430+
function decondition_context(context, args...)
431+
return decondition_context(NodeTrait(context), context, args...)
454432
end
455-
function AbstractPPL.decondition(context::ConditionContext)
456-
return decondition(childcontext(context))
433+
function decondition_context(context::ConditionContext)
434+
return decondition_context(childcontext(context))
457435
end
458-
function AbstractPPL.decondition(context::ConditionContext, sym)
459-
return condition(
460-
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
436+
function decondition_context(context::ConditionContext, sym)
437+
return ConditionContext(
438+
decondition_context(childcontext(context), sym), BangBang.delete!!(context.values, sym)
461439
)
462440
end
463-
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
464-
return decondition(
465-
condition(
466-
decondition(childcontext(context), syms...),
441+
function decondition_context(context::ConditionContext, sym, syms...)
442+
return decondition_context(
443+
ConditionContext(
444+
decondition_context(childcontext(context), syms...),
467445
BangBang.delete!!(context.values, sym),
468446
),
469447
syms...,
470448
)
471449
end
472450

473-
function AbstractPPL.decondition(
451+
function decondition_context(
474452
context::NamedConditionContext, vn::VarName{sym}
475453
) where {sym}
476-
return condition(
477-
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
454+
return ConditionContext(
455+
decondition_context(childcontext(context), vn), BangBang.delete!!(context.values, sym)
478456
)
479457
end
480-
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
481-
return condition(
482-
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
458+
function decondition_context(context::ConditionContext, vn::VarName)
459+
return ConditionContext(
460+
decondition_context(childcontext(context), vn), BangBang.delete!!(context.values, vn)
483461
)
484462
end
485463

src/model.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio
9696
9797
See [`condition`](@ref) for more information and examples.
9898
"""
99-
Base.:|(model::Model, values) = condition(model, values)
99+
Base.:|(model::Model, values::Union{Tuple,NamedTuple,AbstractDict{<:VarName}}) =
100+
condition(model, values)
100101

101102
"""
102103
condition(model::Model; values...)
@@ -264,11 +265,32 @@ julia> conditioned_model_dict()
264265
1.0
265266
```
266267
"""
267-
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
268-
function AbstractPPL.condition(model::Model, value, values...)
269-
return contextualize(model, condition(model.context, value, values...))
268+
function AbstractPPL.condition(model::Model, values...)
269+
# Positional arguments - need to handle cases carefully
270+
return contextualize(
271+
model, ConditionContext(_make_conditioning_values(values...), model.context)
272+
)
273+
end
274+
function AbstractPPL.condition(model::Model; values...)
275+
# Keyword arguments -- just convert to a NamedTuple
276+
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
270277
end
271278

279+
"""
280+
_make_conditioning_values(vals...)
281+
282+
Convert different types of input to either a `NamedTuple` or `AbstractDict` of
283+
conditioning values, suitable for storage in a `ConditionContext`.
284+
"""
285+
# Case 1: Already in the right format, e.g. condition(model, (x=1, y=2))
286+
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
287+
# Case 2: condition(model, (@varname(x) => 1, @varname(y) => 2))
288+
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
289+
# Case 3: Case 1 but splatted, e.g. condition(model, x=1, y=2)
290+
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
291+
# Case 4: Case 2 but splatted, e.g. condition(model, @varname(x) => 1, @varname(y) => 2)
292+
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)
293+
272294
"""
273295
decondition(model::Model)
274296
decondition(model::Model, variables...)

0 commit comments

Comments
 (0)