Skip to content

Commit 824f712

Browse files
Added check_model and sub-module DebugUtils (#540)
* initial work on model checking * use record_pre_tilde!, record_post_tilde!, etc. instead of just a single record_tilde! + support for dot tilde + return issuccess and additional info in check_model * added test_context_interface to TestUtils * added tests for check_model * moved debug contexts and check_model to a separate file * export check_model + make DebugContext take the model as input so we can further customize * noticed I forgot to include check_models.jl file * fixed tests * added record-methods for observe statements too * use explicit types for the recorded tilde statements + added convenient show methods to make displaying the trace nicer * renamd check__model to debug_utils and put it into a module * renamed test/check_model.jl to test/debug_utils.jl * removed unnecessary stuff in tests * added test for logging of statements * removed unnecessary splatting in broadcasting + improved errors for encountering missing * added missing implementation of tilde_observe for PrefixContext * re-ordered method implementations for DebugContext to make things a bit more readable * addeed error message indicating that usage of missing for de-conditioning is restricted to univariate distributions * added missing left field to ObserveStmt * fixed conditioned * fixed `fixed` too, and moved the `_merge` to a more sensible location * added check_model_post_evaluation and made it so we're using SamplingContext by default since we're using an empty VarInfo by default * removed show_statements * perform some simple checks to make sure show is working for statements * improved test for show of statements a tiny bit * added some more docs * more docs * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed typo in warning * moved inclusion of trace and others in return-value from check_model to check_model_and_extras * formatting * drop returning varnames_seen and renamed check_model_and_extras to check_model_and_trace * drop export of DebugContext * added check_model and check_model_and_trace to docs * updated tests * more updates to tests * formatting * added rng as an optional positional argument to check_model methods * added an example in the docstring of check_model_and_trace * added example of correct and incorrect model in check_model_and_trace docstring * Update src/debug_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed docs maybe * fixed docstring maybe * fixed reference to Setfield and tests * fixed docs * added some conveinence methods in addition to a `has_static_constraints` method to empirically check whether the model has static constraints or if they are indeed changing dependent on realizations * improved show for large arrays of varnames whiich can occur in dot-tilde statements --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 33a84c7 commit 824f712

File tree

8 files changed

+815
-6
lines changed

8 files changed

+815
-6
lines changed

docs/src/api.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ DynamicPPL.TestUtils.update_values!!
187187
DynamicPPL.TestUtils.test_values
188188
```
189189

190+
## Debugging Utilities
191+
192+
DynamicPPL provides a few methods for checking validity of a model-definition.
193+
194+
```@docs
195+
check_model
196+
check_model_and_trace
197+
```
198+
199+
And some which might be useful to determine certain properties of the model based on the debug trace.
200+
201+
```@docs
202+
DynamicPPL.has_static_constraints
203+
```
204+
190205
## Advanced
191206

192207
### Variable names

src/DynamicPPL.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ export AbstractVarInfo,
8383
vectorize,
8484
reconstruct,
8585
reconstruct!,
86-
Sample,
8786
init,
8887
vectorize,
8988
OrderedDict,
@@ -130,7 +129,9 @@ export AbstractVarInfo,
130129
# Convenience macros
131130
@addlogprob!,
132131
@submodel,
133-
value_iterator_from_chain
132+
value_iterator_from_chain,
133+
check_model,
134+
check_model_and_trace
134135

135136
# Reexport
136137
using Distributions: loglikelihood
@@ -179,6 +180,9 @@ include("logdensityfunction.jl")
179180
include("model_utils.jl")
180181
include("extract_priors.jl")
181182

183+
include("debug_utils.jl")
184+
using .DebugUtils
185+
182186
if !isdefined(Base, :get_extension)
183187
using Requires
184188
end

src/context_implementations.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ end
177177
function tilde_observe(context::PrefixContext, right, left, vi)
178178
return tilde_observe(context.context, right, left, vi)
179179
end
180+
function tilde_observe(context::PrefixContext, sampler, right, left, vi)
181+
return tilde_observe(context.context, sampler, right, left, vi)
182+
end
180183

181184
"""
182185
tilde_observe!!(context, right, left, vname, vi)

src/contexts.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,14 @@ a merged version of the condition values.
479479
function conditioned(context::AbstractContext)
480480
return conditioned(NodeTrait(conditioned, context), context)
481481
end
482-
conditioned(::IsLeaf, context) = ()
482+
conditioned(::IsLeaf, context) = NamedTuple()
483483
conditioned(::IsParent, context) = conditioned(childcontext(context))
484484
function conditioned(context::ConditionContext)
485485
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
486486
# is that the outermost `context` takes precendence, hence when resolving
487487
# the `conditioned` variables we need to ensure that `context.values` takes
488488
# precedence over decendants of `context`.
489-
return merge(context.values, conditioned(childcontext(context)))
489+
return _merge(context.values, conditioned(childcontext(context)))
490490
end
491491

492492
struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
@@ -655,12 +655,12 @@ Note that this will recursively traverse the context stack and return
655655
a merged version of the fix values.
656656
"""
657657
fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context)
658-
fixed(::IsLeaf, context) = ()
658+
fixed(::IsLeaf, context) = NamedTuple()
659659
fixed(::IsParent, context) = fixed(childcontext(context))
660660
function fixed(context::FixedContext)
661661
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
662662
# is that the outermost `context` takes precendence, hence when resolving
663663
# the `fixed` variables we need to ensure that `context.values` takes
664664
# precedence over decendants of `context`.
665-
return merge(context.values, fixed(childcontext(context)))
665+
return _merge(context.values, fixed(childcontext(context)))
666666
end

0 commit comments

Comments
 (0)