|
15 | 15 | from pymc.util import RandomState
|
16 | 16 | from pytensor import Variable, graph_replace
|
17 | 17 | from pytensor.compile import get_mode
|
| 18 | +from rich.box import SIMPLE_HEAD |
| 19 | +from rich.console import Console |
| 20 | +from rich.table import Table |
18 | 21 |
|
19 | 22 | from pymc_extras.statespace.core.representation import PytensorRepresentation
|
20 | 23 | from pymc_extras.statespace.filters import (
|
@@ -254,52 +257,54 @@ def __init__(
|
254 | 257 | self.kalman_smoother = KalmanSmoother()
|
255 | 258 | self.make_symbolic_graph()
|
256 | 259 |
|
| 260 | + self.requirement_table = Table( |
| 261 | + show_header=True, |
| 262 | + show_edge=True, |
| 263 | + box=SIMPLE_HEAD, |
| 264 | + highlight=True, |
| 265 | + ) |
| 266 | + |
| 267 | + self.requirement_table.title = "Model Requirements" |
| 268 | + self.requirement_table.caption = ( |
| 269 | + "These parameters should be assigned priors inside a PyMC model block before " |
| 270 | + "calling the build_statespace_graph method." |
| 271 | + ) |
| 272 | + |
| 273 | + self.requirement_table.add_column("Variable", justify="right") |
| 274 | + self.requirement_table.add_column("Shape", justify="left") |
| 275 | + self.requirement_table.add_column("Constraints", justify="left") |
| 276 | + self.requirement_table.add_column("Dimensions", justify="right") |
| 277 | + |
| 278 | + self._populate_prior_requirements() |
| 279 | + self._populate_data_requirements() |
| 280 | + |
257 | 281 | if verbose:
|
258 |
| - # These are split into separate try-except blocks, because it will be quite rare of models to implement |
259 |
| - # _print_data_requirements, but we still want to print the prior requirements. |
260 |
| - try: |
261 |
| - self._print_prior_requirements() |
262 |
| - except NotImplementedError: |
263 |
| - pass |
264 |
| - try: |
265 |
| - self._print_data_requirements() |
266 |
| - except NotImplementedError: |
267 |
| - pass |
268 |
| - |
269 |
| - def _print_prior_requirements(self) -> None: |
270 |
| - """ |
271 |
| - Prints a short report to the terminal about the priors needed for the model, including their names, |
| 282 | + console = Console() |
| 283 | + console.print(self.requirement_table) |
| 284 | + |
| 285 | + def _populate_prior_requirements(self) -> None: |
| 286 | + """ |
| 287 | + Add requirements about priors needed for the model to a rich table, including their names, |
272 | 288 | shapes, named dimensions, and any parameter constraints.
|
273 | 289 | """
|
274 |
| - out = "" |
275 | 290 | for param, info in self.param_info.items():
|
276 |
| - out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n' |
277 |
| - out = out.rstrip() |
278 |
| - |
279 |
| - _log.info( |
280 |
| - "The following parameters should be assigned priors inside a PyMC " |
281 |
| - f"model block: \n" |
282 |
| - f"{out}" |
283 |
| - ) |
| 291 | + self.requirement_table.add_row( |
| 292 | + param, str(info["shape"]), info["constraints"], str(info["dims"]) |
| 293 | + ) |
284 | 294 |
|
285 |
| - def _print_data_requirements(self) -> None: |
| 295 | + def _populate_data_requirements(self) -> None: |
286 | 296 | """
|
287 |
| - Prints a short report to the terminal about the data needed for the model, including their names, shapes, |
288 |
| - and named dimensions. |
| 297 | + Add requirements about the data needed for the model, including their names, shapes, and named dimensions. |
289 | 298 | """
|
290 |
| - if not self.data_info: |
| 299 | + try: |
| 300 | + self.data_info |
| 301 | + except NotImplementedError: |
291 | 302 | return
|
292 | 303 |
|
293 |
| - out = "" |
294 |
| - for data, info in self.data_info.items(): |
295 |
| - out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n' |
296 |
| - out = out.rstrip() |
| 304 | + self.requirement_table.add_section() |
297 | 305 |
|
298 |
| - _log.info( |
299 |
| - "The following Data variables should be assigned to the model inside a PyMC " |
300 |
| - f"model block: \n" |
301 |
| - f"{out}" |
302 |
| - ) |
| 306 | + for data, info in self.data_info.items(): |
| 307 | + self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"])) |
303 | 308 |
|
304 | 309 | def _unpack_statespace_with_placeholders(
|
305 | 310 | self,
|
|
0 commit comments