Skip to content

Commit 08ef935

Browse files
committed
Condition with Dict as underlying storage (#419)
This PR allows usage of `Dict` as the underlying storage in addition to the currently supported `NamedTuple`. Similarly to `SimpleVarInfo`, this gives us two approaches: one with somewhat limited support, as outlined in the docstring, but with (usually) zero runtime overhead (`NamedTuple`), and one with full support but with runtime overead (`Dict`).
1 parent 0ba86e2 commit 08ef935

File tree

5 files changed

+313
-128
lines changed

5 files changed

+313
-128
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.20.0"
3+
version = "0.20.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/contexts.jl

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -272,28 +272,18 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
272272
end
273273
end
274274

275-
struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext
275+
struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
276276
values::Values
277277
context::Ctx
278-
279-
function ConditionContext{Values}(
280-
values::Values, context::AbstractContext
281-
) where {names,Values<:NamedTuple{names}}
282-
return new{names,typeof(values),typeof(context)}(values, context)
283-
end
284278
end
285279

286-
function ConditionContext(values::NamedTuple)
287-
return ConditionContext(values, DefaultContext())
288-
end
289-
function ConditionContext(values::NamedTuple, context::AbstractContext)
290-
return ConditionContext{typeof(values)}(values, context)
291-
end
280+
const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
281+
const DictConditionContext = ConditionContext{<:AbstractDict}
282+
283+
ConditionContext(values) = ConditionContext(values, DefaultContext())
292284

293285
# Try to avoid nested `ConditionContext`.
294-
function ConditionContext(
295-
values::NamedTuple{Names}, context::ConditionContext
296-
) where {Names}
286+
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
297287
# Note that this potentially overrides values from `context`, thus giving
298288
# precedence to the outmost `ConditionContext`.
299289
return ConditionContext(merge(context.values, values), childcontext(context))
@@ -303,7 +293,7 @@ function Base.show(io::IO, context::ConditionContext)
303293
return print(io, "ConditionContext($(context.values), $(childcontext(context)))")
304294
end
305295

306-
NodeTrait(context::ConditionContext) = IsParent()
296+
NodeTrait(::ConditionContext) = IsParent()
307297
childcontext(context::ConditionContext) = context.context
308298
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)
309299

@@ -313,14 +303,9 @@ setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.value
313303
Return `true` if `vn` is found in `context`.
314304
"""
315305
hasvalue(context, vn) = false
316-
317-
function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
318-
return sym in vars
319-
end
320-
function hasvalue(
321-
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
322-
) where {vars,sym}
323-
return sym in vars
306+
hasvalue(context::ConditionContext, vn::VarName) = nested_haskey(context.values, vn)
307+
function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName})
308+
return all(Base.Fix1(nested_haskey, context.values), vns)
324309
end
325310

326311
"""
@@ -331,7 +316,8 @@ Return value of `vn` in `context`.
331316
function getvalue(context::AbstractContext, vn)
332317
return error("context $(context) does not contain value for $vn")
333318
end
334-
getvalue(context::ConditionContext, vn) = get(context.values, vn)
319+
getvalue(context::NamedConditionContext, vn) = get(context.values, vn)
320+
getvalue(context::ConditionContext, vn) = nested_getindex(context.values, vn)
335321

336322
"""
337323
hasvalue_nested(context, vn)
@@ -386,15 +372,33 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
386372
387373
See also: [`decondition`](@ref)
388374
"""
389-
AbstractPPL.condition(; values...) = condition(DefaultContext(), NamedTuple(values))
375+
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
390376
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
377+
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
378+
return condition((value, values...))
379+
end
380+
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
381+
return condition(DefaultContext(), values)
382+
end
391383
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
392-
function AbstractPPL.condition(context::AbstractContext, values::NamedTuple)
384+
function AbstractPPL.condition(
385+
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
386+
)
393387
return ConditionContext(values, context)
394388
end
395389
function AbstractPPL.condition(context::AbstractContext; values...)
396390
return condition(context, NamedTuple(values))
397391
end
392+
function AbstractPPL.condition(
393+
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
394+
)
395+
return condition(context, (value, values...))
396+
end
397+
function AbstractPPL.condition(
398+
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
399+
)
400+
return condition(context, Dict(values))
401+
end
398402

399403
"""
400404
decondition(context::AbstractContext, syms...)
@@ -430,6 +434,19 @@ function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
430434
)
431435
end
432436

437+
function AbstractPPL.decondition(
438+
context::NamedConditionContext, vn::VarName{sym}
439+
) where {sym}
440+
return condition(
441+
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
442+
)
443+
end
444+
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
445+
return condition(
446+
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
447+
)
448+
end
449+
433450
"""
434451
conditioned(context::AbstractContext)
435452

0 commit comments

Comments
 (0)