@@ -257,45 +257,10 @@ def __init__(
257
257
self .kalman_smoother = KalmanSmoother ()
258
258
self .make_symbolic_graph ()
259
259
260
- self .requirement_table = Table (
261
- show_header = True ,
262
- show_edge = True ,
263
- box = SIMPLE_HEAD ,
264
- highlight = True ,
265
- )
266
-
267
- self .requirement_table .title = "Model Requirements"
268
- self .requirement_table .caption = (
269
- "These parameters should be assigned priors inside a PyMC model block before "
270
- "calling the build_statespace_graph method."
271
- )
260
+ self ._populate_prior_requirements ()
261
+ self ._populate_data_requirements ()
272
262
273
- self .requirement_table .add_column ("Variable" , justify = "left" )
274
- self .requirement_table .add_column ("Shape" , justify = "left" )
275
- self .requirement_table .add_column ("Constraints" , justify = "left" )
276
- self .requirement_table .add_column ("Dimensions" , justify = "right" )
277
-
278
- has_prior_info = False
279
- has_data_info = False
280
- try :
281
- self .param_info
282
- has_prior_info = True
283
- except NotImplementedError :
284
- pass
285
-
286
- try :
287
- self .data_info
288
- has_data_info = True
289
- except NotImplementedError :
290
- pass
291
-
292
- if has_prior_info :
293
- self ._populate_prior_requirements ()
294
-
295
- if has_data_info :
296
- self ._populate_data_requirements ()
297
-
298
- if verbose and (has_prior_info or has_data_info ):
263
+ if verbose and self .requirement_table :
299
264
console = Console ()
300
265
console .print (self .requirement_table )
301
266
@@ -304,6 +269,17 @@ def _populate_prior_requirements(self) -> None:
304
269
Add requirements about priors needed for the model to a rich table, including their names,
305
270
shapes, named dimensions, and any parameter constraints.
306
271
"""
272
+ # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
273
+ # is not true.
274
+ try :
275
+ if not isinstance (self .param_info , dict ):
276
+ return
277
+ except NotImplementedError :
278
+ return
279
+
280
+ if self .requirement_table is None :
281
+ self ._initialize_requirement_table ()
282
+
307
283
for param , info in self .param_info .items ():
308
284
self .requirement_table .add_row (
309
285
param , str (info ["shape" ]), info ["constraints" ], str (info ["dims" ])
@@ -313,11 +289,39 @@ def _populate_data_requirements(self) -> None:
313
289
"""
314
290
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
315
291
"""
316
- self .requirement_table .add_section ()
292
+ try :
293
+ if not isinstance (self .data_info , dict ):
294
+ return
295
+ except NotImplementedError :
296
+ return
297
+
298
+ if self .requirement_table is None :
299
+ self ._initialize_requirement_table ()
300
+ else :
301
+ self .requirement_table .add_section ()
317
302
318
303
for data , info in self .data_info .items ():
319
304
self .requirement_table .add_row (data , str (info ["shape" ]), "pm.Data" , str (info ["dims" ]))
320
305
306
+ def _initialize_requirement_table (self ) -> None :
307
+ self .requirement_table = Table (
308
+ show_header = True ,
309
+ show_edge = True ,
310
+ box = SIMPLE_HEAD ,
311
+ highlight = True ,
312
+ )
313
+
314
+ self .requirement_table .title = "Model Requirements"
315
+ self .requirement_table .caption = (
316
+ "These parameters should be assigned priors inside a PyMC model block before "
317
+ "calling the build_statespace_graph method."
318
+ )
319
+
320
+ self .requirement_table .add_column ("Variable" , justify = "left" )
321
+ self .requirement_table .add_column ("Shape" , justify = "left" )
322
+ self .requirement_table .add_column ("Constraints" , justify = "left" )
323
+ self .requirement_table .add_column ("Dimensions" , justify = "right" )
324
+
321
325
def _unpack_statespace_with_placeholders (
322
326
self ,
323
327
) -> tuple [
0 commit comments