@@ -309,16 +309,31 @@ function prefix(model::Model, ::Val{x}) where {x}
309
309
return contextualize (model, PrefixContext {Symbol(x)} (model. context))
310
310
end
311
311
312
- struct ConditionContext{Values,Ctx<: AbstractContext } <: AbstractContext
312
+ """
313
+
314
+ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
315
+
316
+ Model context that contains values that are to be conditioned on. The values
317
+ can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
318
+ an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
319
+ @varname(b) => 2)`). The former is more performant, but the latter must be used
320
+ when there are varnames that cannot be represented as symbols, e.g.
321
+ `@varname(x[1])`.
322
+ """
323
+ struct ConditionContext{
324
+ Values<: Union{NamedTuple,AbstractDict{<:VarName}} ,Ctx<: AbstractContext
325
+ } <: AbstractContext
313
326
values:: Values
314
327
context:: Ctx
315
328
end
316
329
317
330
const NamedConditionContext{Names} = ConditionContext{<: NamedTuple{Names} }
318
331
const DictConditionContext = ConditionContext{<: AbstractDict }
319
332
320
- ConditionContext (values) = ConditionContext (values, DefaultContext ())
321
-
333
+ # Use DefaultContext as the default base context
334
+ ConditionContext (values:: Union{NamedTuple,AbstractDict} ) = ConditionContext (values, DefaultContext ())
335
+ # Optimisation when there are no values to condition on
336
+ ConditionContext (:: NamedTuple{()} , context:: AbstractContext ) = context
322
337
# Try to avoid nested `ConditionContext`.
323
338
function ConditionContext (values:: NamedTuple , context:: NamedConditionContext )
324
339
# Note that this potentially overrides values from `context`, thus giving
@@ -399,43 +414,6 @@ function getconditioned_nested(::IsParent, context, vn)
399
414
end
400
415
end
401
416
402
- """
403
- condition([context::AbstractContext,] values::NamedTuple)
404
- condition([context::AbstractContext]; values...)
405
-
406
- Return `ConditionContext` with `values` and `context` if `values` is non-empty,
407
- otherwise return `context` which is [`DefaultContext`](@ref) by default.
408
-
409
- See also: [`decondition`](@ref)
410
- """
411
- AbstractPPL. condition (; values... ) = condition (NamedTuple (values))
412
- AbstractPPL. condition (values:: NamedTuple ) = condition (DefaultContext (), values)
413
- function AbstractPPL. condition (value:: Pair{<:VarName} , values:: Pair{<:VarName} ...)
414
- return condition ((value, values... ))
415
- end
416
- function AbstractPPL. condition (values:: NTuple{<:Any,<:Pair{<:VarName}} )
417
- return condition (DefaultContext (), values)
418
- end
419
- AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
420
- function AbstractPPL. condition (
421
- context:: AbstractContext , values:: Union{AbstractDict,NamedTuple}
422
- )
423
- return ConditionContext (values, context)
424
- end
425
- function AbstractPPL. condition (context:: AbstractContext ; values... )
426
- return condition (context, NamedTuple (values))
427
- end
428
- function AbstractPPL. condition (
429
- context:: AbstractContext , value:: Pair{<:VarName} , values:: Pair{<:VarName} ...
430
- )
431
- return condition (context, (value, values... ))
432
- end
433
- function AbstractPPL. condition (
434
- context:: AbstractContext , values:: NTuple{<:Any,Pair{<:VarName}}
435
- )
436
- return condition (context, Dict (values))
437
- end
438
-
439
417
"""
440
418
decondition(context::AbstractContext, syms...)
441
419
@@ -445,41 +423,41 @@ Note that this recursively traverses contexts, deconditioning all along the way.
445
423
446
424
See also: [`condition`](@ref)
447
425
"""
448
- AbstractPPL . decondition (:: IsLeaf , context, args... ) = context
449
- function AbstractPPL . decondition (:: IsParent , context, args... )
450
- return setchildcontext (context, decondition (childcontext (context), args... ))
426
+ decondition_context (:: IsLeaf , context, args... ) = context
427
+ function decondition_context (:: IsParent , context, args... )
428
+ return setchildcontext (context, decondition_context (childcontext (context), args... ))
451
429
end
452
- function AbstractPPL . decondition (context, args... )
453
- return decondition (NodeTrait (context), context, args... )
430
+ function decondition_context (context, args... )
431
+ return decondition_context (NodeTrait (context), context, args... )
454
432
end
455
- function AbstractPPL . decondition (context:: ConditionContext )
456
- return decondition (childcontext (context))
433
+ function decondition_context (context:: ConditionContext )
434
+ return decondition_context (childcontext (context))
457
435
end
458
- function AbstractPPL . decondition (context:: ConditionContext , sym)
459
- return condition (
460
- decondition (childcontext (context), sym), BangBang. delete!! (context. values, sym)
436
+ function decondition_context (context:: ConditionContext , sym)
437
+ return ConditionContext (
438
+ decondition_context (childcontext (context), sym), BangBang. delete!! (context. values, sym)
461
439
)
462
440
end
463
- function AbstractPPL . decondition (context:: ConditionContext , sym, syms... )
464
- return decondition (
465
- condition (
466
- decondition (childcontext (context), syms... ),
441
+ function decondition_context (context:: ConditionContext , sym, syms... )
442
+ return decondition_context (
443
+ ConditionContext (
444
+ decondition_context (childcontext (context), syms... ),
467
445
BangBang. delete!! (context. values, sym),
468
446
),
469
447
syms... ,
470
448
)
471
449
end
472
450
473
- function AbstractPPL . decondition (
451
+ function decondition_context (
474
452
context:: NamedConditionContext , vn:: VarName{sym}
475
453
) where {sym}
476
- return condition (
477
- decondition (childcontext (context), vn), BangBang. delete!! (context. values, sym)
454
+ return ConditionContext (
455
+ decondition_context (childcontext (context), vn), BangBang. delete!! (context. values, sym)
478
456
)
479
457
end
480
- function AbstractPPL . decondition (context:: ConditionContext , vn:: VarName )
481
- return condition (
482
- decondition (childcontext (context), vn), BangBang. delete!! (context. values, vn)
458
+ function decondition_context (context:: ConditionContext , vn:: VarName )
459
+ return ConditionContext (
460
+ decondition_context (childcontext (context), vn), BangBang. delete!! (context. values, vn)
483
461
)
484
462
end
485
463
0 commit comments