@@ -251,3 +251,176 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
251
251
VarName {Symbol(Prefix, PREFIX_SEPARATOR, Sym)} (vn. indexing)
252
252
end
253
253
end
254
+
255
+ struct ConditionContext{Names,Values,Ctx<: AbstractContext } <: AbstractContext
256
+ values:: Values
257
+ context:: Ctx
258
+
259
+ function ConditionContext {Values} (
260
+ values:: Values , context:: AbstractContext
261
+ ) where {names,Values<: NamedTuple{names} }
262
+ return new {names,typeof(values),typeof(context)} (values, context)
263
+ end
264
+ end
265
+
266
+ function ConditionContext (values:: NamedTuple )
267
+ return ConditionContext (values, DefaultContext ())
268
+ end
269
+ function ConditionContext (values:: NamedTuple , context:: AbstractContext )
270
+ return ConditionContext {typeof(values)} (values, context)
271
+ end
272
+
273
+ # Try to avoid nested `ConditionContext`.
274
+ function ConditionContext (
275
+ values:: NamedTuple{Names} , context:: ConditionContext
276
+ ) where {Names}
277
+ # Note that this potentially overrides values from `context`, thus giving
278
+ # precedence to the outmost `ConditionContext`.
279
+ return ConditionContext (merge (context. values, values), childcontext (context))
280
+ end
281
+
282
+ function Base. show (io:: IO , context:: ConditionContext )
283
+ return print (io, " ConditionContext($(context. values) , $(childcontext (context)) )" )
284
+ end
285
+
286
+ NodeTrait (context:: ConditionContext ) = IsParent ()
287
+ childcontext (context:: ConditionContext ) = context. context
288
+ setchildcontext (parent:: ConditionContext , child) = ConditionContext (parent. values, child)
289
+
290
+ """
291
+ hasvalue(context, vn)
292
+
293
+ Return `true` if `vn` is found in `context`.
294
+ """
295
+ hasvalue (context, vn) = false
296
+
297
+ function hasvalue (context:: ConditionContext{vars} , vn:: VarName{sym} ) where {vars,sym}
298
+ return sym in vars
299
+ end
300
+ function hasvalue (
301
+ context:: ConditionContext{vars} , vn:: AbstractArray{<:VarName{sym}}
302
+ ) where {vars,sym}
303
+ return sym in vars
304
+ end
305
+
306
+ """
307
+ getvalue(context, vn)
308
+
309
+ Return value of `vn` in `context`.
310
+ """
311
+ function getvalue (context:: AbstractContext , vn)
312
+ return error (" context $(context) does not contain value for $vn " )
313
+ end
314
+ getvalue (context:: ConditionContext , vn) = _getvalue (context. values, vn)
315
+
316
+ """
317
+ hasvalue_nested(context, vn)
318
+
319
+ Return `true` if `vn` is found in `context` or any of its descendants.
320
+
321
+ This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`,
322
+ not recursively checking if `vn` is in any of its descendants.
323
+ """
324
+ function hasvalue_nested (context:: AbstractContext , vn)
325
+ return hasvalue_nested (NodeTrait (hasvalue_nested, context), context, vn)
326
+ end
327
+ hasvalue_nested (:: IsLeaf , context, vn) = hasvalue (context, vn)
328
+ function hasvalue_nested (:: IsParent , context, vn)
329
+ return hasvalue (context, vn) || hasvalue_nested (childcontext (context), vn)
330
+ end
331
+ function hasvalue_nested (context:: PrefixContext , vn)
332
+ return hasvalue_nested (childcontext (context), prefix (context, vn))
333
+ end
334
+
335
+ """
336
+ getvalue_nested(context, vn)
337
+
338
+ Return the value of the parameter corresponding to `vn` from `context` or its descendants.
339
+
340
+ This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`,
341
+ not recursively looking into its descendants.
342
+ """
343
+ function getvalue_nested (context:: AbstractContext , vn)
344
+ return getvalue_nested (NodeTrait (getvalue_nested, context), context, vn)
345
+ end
346
+ function getvalue_nested (:: IsLeaf , context, vn)
347
+ return error (" context $(context) does not contain value for $vn " )
348
+ end
349
+ function getvalue_nested (context:: PrefixContext , vn)
350
+ return getvalue_nested (childcontext (context), prefix (context, vn))
351
+ end
352
+ function getvalue_nested (:: IsParent , context, vn)
353
+ return if hasvalue (context, vn)
354
+ getvalue (context, vn)
355
+ else
356
+ getvalue_nested (childcontext (context), vn)
357
+ end
358
+ end
359
+
360
+ """
361
+ condition([context::AbstractContext,] values::NamedTuple)
362
+ condition([context::AbstractContext]; values...)
363
+
364
+ Return `ConditionContext` with `values` and `context` if `values` is non-empty,
365
+ otherwise return `context` which is [`DefaultContext`](@ref) by default.
366
+
367
+ See also: [`decondition`](@ref)
368
+ """
369
+ condition (; values... ) = condition (DefaultContext (), NamedTuple (values))
370
+ condition (values:: NamedTuple ) = condition (DefaultContext (), values)
371
+ condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
372
+ condition (context:: AbstractContext , values:: NamedTuple ) = ConditionContext (values, context)
373
+ condition (context:: AbstractContext ; values... ) = condition (context, NamedTuple (values))
374
+
375
+ """
376
+ decondition(context::AbstractContext, syms...)
377
+
378
+ Return `context` but with `syms` no longer conditioned on.
379
+
380
+ Note that this recursively traverses contexts, deconditioning all along the way.
381
+
382
+ See also: [`condition`](@ref)
383
+ """
384
+ decondition (:: IsLeaf , context, args... ) = context
385
+ function decondition (:: IsParent , context, args... )
386
+ return setchildcontext (context, decondition (childcontext (context), args... ))
387
+ end
388
+ decondition (context, args... ) = decondition (NodeTrait (context), context, args... )
389
+ function decondition (context:: ConditionContext )
390
+ return decondition (childcontext (context))
391
+ end
392
+ function decondition (context:: ConditionContext , sym)
393
+ return condition (
394
+ decondition (childcontext (context), sym), BangBang. delete!! (context. values, sym)
395
+ )
396
+ end
397
+ function decondition (context:: ConditionContext , sym, syms... )
398
+ return decondition (
399
+ condition (
400
+ decondition (childcontext (context), syms... ),
401
+ BangBang. delete!! (context. values, sym),
402
+ ),
403
+ syms... ,
404
+ )
405
+ end
406
+
407
+ """
408
+ conditioned(context::AbstractContext)
409
+
410
+ Return `NamedTuple` of values that are conditioned on under context`.
411
+
412
+ Note that this will recursively traverse the context stack and return
413
+ a merged version of the condition values.
414
+ """
415
+ function conditioned (context:: AbstractContext )
416
+ return conditioned (NodeTrait (conditioned, context), context)
417
+ end
418
+ conditioned (:: IsLeaf , context) = ()
419
+ conditioned (:: IsParent , context) = conditioned (childcontext (context))
420
+ function conditioned (context:: ConditionContext )
421
+ # Note the order of arguments to `merge`. The behavior of the rest of DPPL
422
+ # is that the outermost `context` takes precendence, hence when resolving
423
+ # the `conditioned` variables we need to ensure that `context.values` takes
424
+ # precedence over decendants of `context`.
425
+ return merge (context. values, conditioned (childcontext (context)))
426
+ end
0 commit comments