Skip to content

Commit bc1625a

Browse files
Refactor table initialization
1 parent 4b18d2a commit bc1625a

File tree

1 file changed

+43
-39
lines changed

1 file changed

+43
-39
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -257,45 +257,10 @@ def __init__(
257257
self.kalman_smoother = KalmanSmoother()
258258
self.make_symbolic_graph()
259259

260-
self.requirement_table = Table(
261-
show_header=True,
262-
show_edge=True,
263-
box=SIMPLE_HEAD,
264-
highlight=True,
265-
)
266-
267-
self.requirement_table.title = "Model Requirements"
268-
self.requirement_table.caption = (
269-
"These parameters should be assigned priors inside a PyMC model block before "
270-
"calling the build_statespace_graph method."
271-
)
260+
self._populate_prior_requirements()
261+
self._populate_data_requirements()
272262

273-
self.requirement_table.add_column("Variable", justify="left")
274-
self.requirement_table.add_column("Shape", justify="left")
275-
self.requirement_table.add_column("Constraints", justify="left")
276-
self.requirement_table.add_column("Dimensions", justify="right")
277-
278-
has_prior_info = False
279-
has_data_info = False
280-
try:
281-
self.param_info
282-
has_prior_info = True
283-
except NotImplementedError:
284-
pass
285-
286-
try:
287-
self.data_info
288-
has_data_info = True
289-
except NotImplementedError:
290-
pass
291-
292-
if has_prior_info:
293-
self._populate_prior_requirements()
294-
295-
if has_data_info:
296-
self._populate_data_requirements()
297-
298-
if verbose and (has_prior_info or has_data_info):
263+
if verbose and self.requirement_table:
299264
console = Console()
300265
console.print(self.requirement_table)
301266

@@ -304,6 +269,17 @@ def _populate_prior_requirements(self) -> None:
304269
Add requirements about priors needed for the model to a rich table, including their names,
305270
shapes, named dimensions, and any parameter constraints.
306271
"""
272+
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
273+
# is not true.
274+
try:
275+
if not isinstance(self.param_info, dict):
276+
return
277+
except NotImplementedError:
278+
return
279+
280+
if self.requirement_table is None:
281+
self._initialize_requirement_table()
282+
307283
for param, info in self.param_info.items():
308284
self.requirement_table.add_row(
309285
param, str(info["shape"]), info["constraints"], str(info["dims"])
@@ -313,11 +289,39 @@ def _populate_data_requirements(self) -> None:
313289
"""
314290
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
315291
"""
316-
self.requirement_table.add_section()
292+
try:
293+
if not isinstance(self.data_info, dict):
294+
return
295+
except NotImplementedError:
296+
return
297+
298+
if self.requirement_table is None:
299+
self._initialize_requirement_table()
300+
else:
301+
self.requirement_table.add_section()
317302

318303
for data, info in self.data_info.items():
319304
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
320305

306+
def _initialize_requirement_table(self) -> None:
307+
self.requirement_table = Table(
308+
show_header=True,
309+
show_edge=True,
310+
box=SIMPLE_HEAD,
311+
highlight=True,
312+
)
313+
314+
self.requirement_table.title = "Model Requirements"
315+
self.requirement_table.caption = (
316+
"These parameters should be assigned priors inside a PyMC model block before "
317+
"calling the build_statespace_graph method."
318+
)
319+
320+
self.requirement_table.add_column("Variable", justify="left")
321+
self.requirement_table.add_column("Shape", justify="left")
322+
self.requirement_table.add_column("Constraints", justify="left")
323+
self.requirement_table.add_column("Dimensions", justify="right")
324+
321325
def _unpack_statespace_with_placeholders(
322326
self,
323327
) -> tuple[

0 commit comments

Comments
 (0)