@@ -131,9 +131,7 @@ A context used for checking validity of a model.
131
131
# Fields
132
132
$(FIELDS)
133
133
"""
134
- struct DebugContext{M<: Model ,C<: AbstractContext } <: AbstractContext
135
- " model that is being run"
136
- model:: M
134
+ struct DebugContext{C<: AbstractContext } <: AbstractContext
137
135
" context used for running the model"
138
136
context:: C
139
137
" mapping from varnames to the number of times they have been seen"
@@ -149,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext
149
147
end
150
148
151
149
function DebugContext (
152
- model:: Model ,
153
150
context:: AbstractContext = DefaultContext ();
154
151
varnames_seen= OrderedDict {VarName,Int} (),
155
152
statements= Vector {Stmt} (),
@@ -158,7 +155,6 @@ function DebugContext(
158
155
record_varinfo= false ,
159
156
)
160
157
return DebugContext (
161
- model,
162
158
context,
163
159
varnames_seen,
164
160
statements,
@@ -344,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int})
344
340
end
345
341
346
342
# A check we run on the model before evaluating it.
347
- function check_model_pre_evaluation (context :: DebugContext , model:: Model )
343
+ function check_model_pre_evaluation (model:: Model )
348
344
issuccess = true
349
345
# If something is in the model arguments, then it should NOT be in `condition`,
350
346
# nor should there be any symbol present in `condition` that has the same symbol.
@@ -361,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model)
361
357
return issuccess
362
358
end
363
359
364
- function check_model_post_evaluation (context :: DebugContext , model:: Model )
365
- return check_varnames_seen (context. varnames_seen)
360
+ function check_model_post_evaluation (model:: Model )
361
+ return check_varnames_seen (model . context. varnames_seen)
366
362
end
367
363
368
364
"""
@@ -443,21 +439,18 @@ function check_model_and_trace(
443
439
)
444
440
# Execute the model with the debug context.
445
441
debug_context = DebugContext (
446
- model,
447
- SamplingContext (rng, model. context);
448
- error_on_failure= error_on_failure,
449
- kwargs... ,
442
+ SamplingContext (rng, model. context); error_on_failure= error_on_failure, kwargs...
450
443
)
451
444
debug_model = DynamicPPL. contextualize (model, debug_context)
452
445
453
446
# Perform checks before evaluating the model.
454
- issuccess = check_model_pre_evaluation (debug_context, debug_model)
447
+ issuccess = check_model_pre_evaluation (debug_model)
455
448
456
449
# Force single-threaded execution.
457
450
DynamicPPL. evaluate_threadunsafe!! (debug_model, varinfo)
458
451
459
452
# Perform checks after evaluating the model.
460
- issuccess &= check_model_post_evaluation (debug_context, debug_model)
453
+ issuccess &= check_model_post_evaluation (debug_model)
461
454
462
455
if ! issuccess && error_on_failure
463
456
error (" model check failed" )
0 commit comments