@@ -272,28 +272,18 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
272
272
end
273
273
end
274
274
275
- struct ConditionContext{Names, Values,Ctx<: AbstractContext } <: AbstractContext
275
+ struct ConditionContext{Values,Ctx<: AbstractContext } <: AbstractContext
276
276
values:: Values
277
277
context:: Ctx
278
-
279
- function ConditionContext {Values} (
280
- values:: Values , context:: AbstractContext
281
- ) where {names,Values<: NamedTuple{names} }
282
- return new {names,typeof(values),typeof(context)} (values, context)
283
- end
284
278
end
285
279
286
- function ConditionContext (values:: NamedTuple )
287
- return ConditionContext (values, DefaultContext ())
288
- end
289
- function ConditionContext (values:: NamedTuple , context:: AbstractContext )
290
- return ConditionContext {typeof(values)} (values, context)
291
- end
280
+ const NamedConditionContext{Names} = ConditionContext{<: NamedTuple{Names} }
281
+ const DictConditionContext = ConditionContext{<: AbstractDict }
282
+
283
+ ConditionContext (values) = ConditionContext (values, DefaultContext ())
292
284
293
285
# Try to avoid nested `ConditionContext`.
294
- function ConditionContext (
295
- values:: NamedTuple{Names} , context:: ConditionContext
296
- ) where {Names}
286
+ function ConditionContext (values:: NamedTuple , context:: NamedConditionContext )
297
287
# Note that this potentially overrides values from `context`, thus giving
298
288
# precedence to the outmost `ConditionContext`.
299
289
return ConditionContext (merge (context. values, values), childcontext (context))
@@ -303,7 +293,7 @@ function Base.show(io::IO, context::ConditionContext)
303
293
return print (io, " ConditionContext($(context. values) , $(childcontext (context)) )" )
304
294
end
305
295
306
- NodeTrait (context :: ConditionContext ) = IsParent ()
296
+ NodeTrait (:: ConditionContext ) = IsParent ()
307
297
childcontext (context:: ConditionContext ) = context. context
308
298
setchildcontext (parent:: ConditionContext , child) = ConditionContext (parent. values, child)
309
299
@@ -313,14 +303,9 @@ setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.value
313
303
Return `true` if `vn` is found in `context`.
314
304
"""
315
305
hasvalue (context, vn) = false
316
-
317
- function hasvalue (context:: ConditionContext{vars} , vn:: VarName{sym} ) where {vars,sym}
318
- return sym in vars
319
- end
320
- function hasvalue (
321
- context:: ConditionContext{vars} , vn:: AbstractArray{<:VarName{sym}}
322
- ) where {vars,sym}
323
- return sym in vars
306
+ hasvalue (context:: ConditionContext , vn:: VarName ) = nested_haskey (context. values, vn)
307
+ function hasvalue (context:: ConditionContext , vns:: AbstractArray{<:VarName} )
308
+ return all (Base. Fix1 (nested_haskey, context. values), vns)
324
309
end
325
310
326
311
"""
@@ -331,7 +316,8 @@ Return value of `vn` in `context`.
331
316
function getvalue (context:: AbstractContext , vn)
332
317
return error (" context $(context) does not contain value for $vn " )
333
318
end
334
- getvalue (context:: ConditionContext , vn) = get (context. values, vn)
319
+ getvalue (context:: NamedConditionContext , vn) = get (context. values, vn)
320
+ getvalue (context:: ConditionContext , vn) = nested_getindex (context. values, vn)
335
321
336
322
"""
337
323
hasvalue_nested(context, vn)
@@ -386,15 +372,33 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
386
372
387
373
See also: [`decondition`](@ref)
388
374
"""
389
- AbstractPPL. condition (; values... ) = condition (DefaultContext (), NamedTuple (values))
375
+ AbstractPPL. condition (; values... ) = condition (NamedTuple (values))
390
376
AbstractPPL. condition (values:: NamedTuple ) = condition (DefaultContext (), values)
377
+ function AbstractPPL. condition (value:: Pair{<:VarName} , values:: Pair{<:VarName} ...)
378
+ return condition ((value, values... ))
379
+ end
380
+ function AbstractPPL. condition (values:: NTuple{<:Any,<:Pair{<:VarName}} )
381
+ return condition (DefaultContext (), values)
382
+ end
391
383
AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
392
- function AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple )
384
+ function AbstractPPL. condition (
385
+ context:: AbstractContext , values:: Union{AbstractDict,NamedTuple}
386
+ )
393
387
return ConditionContext (values, context)
394
388
end
395
389
function AbstractPPL. condition (context:: AbstractContext ; values... )
396
390
return condition (context, NamedTuple (values))
397
391
end
392
+ function AbstractPPL. condition (
393
+ context:: AbstractContext , value:: Pair{<:VarName} , values:: Pair{<:VarName} ...
394
+ )
395
+ return condition (context, (value, values... ))
396
+ end
397
+ function AbstractPPL. condition (
398
+ context:: AbstractContext , values:: NTuple{<:Any,Pair{<:VarName}}
399
+ )
400
+ return condition (context, Dict (values))
401
+ end
398
402
399
403
"""
400
404
decondition(context::AbstractContext, syms...)
@@ -430,6 +434,19 @@ function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
430
434
)
431
435
end
432
436
437
+ function AbstractPPL. decondition (
438
+ context:: NamedConditionContext , vn:: VarName{sym}
439
+ ) where {sym}
440
+ return condition (
441
+ decondition (childcontext (context), vn), BangBang. delete!! (context. values, sym)
442
+ )
443
+ end
444
+ function AbstractPPL. decondition (context:: ConditionContext , vn:: VarName )
445
+ return condition (
446
+ decondition (childcontext (context), vn), BangBang. delete!! (context. values, vn)
447
+ )
448
+ end
449
+
433
450
"""
434
451
conditioned(context::AbstractContext)
435
452
0 commit comments