Skip to content

Commit 7e3bb32

Browse files
Use rich table in build report
1 parent dcc2bec commit 7e3bb32

File tree

1 file changed

+41
-36
lines changed

1 file changed

+41
-36
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from pymc.util import RandomState
1616
from pytensor import Variable, graph_replace
1717
from pytensor.compile import get_mode
18+
from rich.box import SIMPLE_HEAD
19+
from rich.console import Console
20+
from rich.table import Table
1821

1922
from pymc_extras.statespace.core.representation import PytensorRepresentation
2023
from pymc_extras.statespace.filters import (
@@ -254,52 +257,54 @@ def __init__(
254257
self.kalman_smoother = KalmanSmoother()
255258
self.make_symbolic_graph()
256259

260+
self.requirement_table = Table(
261+
show_header=True,
262+
show_edge=True,
263+
box=SIMPLE_HEAD,
264+
highlight=True,
265+
)
266+
267+
self.requirement_table.title = "Model Requirements"
268+
self.requirement_table.caption = (
269+
"These parameters should be assigned priors inside a PyMC model block before "
270+
"calling the build_statespace_graph method."
271+
)
272+
273+
self.requirement_table.add_column("Variable", justify="right")
274+
self.requirement_table.add_column("Shape", justify="left")
275+
self.requirement_table.add_column("Constraints", justify="left")
276+
self.requirement_table.add_column("Dimensions", justify="right")
277+
278+
self._populate_prior_requirements()
279+
self._populate_data_requirements()
280+
257281
if verbose:
258-
# These are split into separate try-except blocks, because it will be quite rare of models to implement
259-
# _print_data_requirements, but we still want to print the prior requirements.
260-
try:
261-
self._print_prior_requirements()
262-
except NotImplementedError:
263-
pass
264-
try:
265-
self._print_data_requirements()
266-
except NotImplementedError:
267-
pass
268-
269-
def _print_prior_requirements(self) -> None:
270-
"""
271-
Prints a short report to the terminal about the priors needed for the model, including their names,
282+
console = Console()
283+
console.print(self.requirement_table)
284+
285+
def _populate_prior_requirements(self) -> None:
286+
"""
287+
Add requirements about priors needed for the model to a rich table, including their names,
272288
shapes, named dimensions, and any parameter constraints.
273289
"""
274-
out = ""
275290
for param, info in self.param_info.items():
276-
out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
277-
out = out.rstrip()
278-
279-
_log.info(
280-
"The following parameters should be assigned priors inside a PyMC "
281-
f"model block: \n"
282-
f"{out}"
283-
)
291+
self.requirement_table.add_row(
292+
param, str(info["shape"]), info["constraints"], str(info["dims"])
293+
)
284294

285-
def _print_data_requirements(self) -> None:
295+
def _populate_data_requirements(self) -> None:
286296
"""
287-
Prints a short report to the terminal about the data needed for the model, including their names, shapes,
288-
and named dimensions.
297+
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
289298
"""
290-
if not self.data_info:
299+
try:
300+
self.data_info
301+
except NotImplementedError:
291302
return
292303

293-
out = ""
294-
for data, info in self.data_info.items():
295-
out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
296-
out = out.rstrip()
304+
self.requirement_table.add_section()
297305

298-
_log.info(
299-
"The following Data variables should be assigned to the model inside a PyMC "
300-
f"model block: \n"
301-
f"{out}"
302-
)
306+
for data, info in self.data_info.items():
307+
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
303308

304309
def _unpack_statespace_with_placeholders(
305310
self,

0 commit comments

Comments
 (0)