Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,243 changes: 751 additions & 492 deletions notebooks/Exponential Trend Smoothing.ipynb

Large diffs are not rendered by default.

578 changes: 424 additions & 154 deletions notebooks/Making a Custom Statespace Model.ipynb

Large diffs are not rendered by default.

1,620 changes: 922 additions & 698 deletions notebooks/SARMA Example.ipynb

Large diffs are not rendered by default.

1,546 changes: 966 additions & 580 deletions notebooks/Structural Timeseries Modeling.ipynb

Large diffs are not rendered by default.

656 changes: 350 additions & 306 deletions notebooks/VARMAX Example.ipynb

Large diffs are not rendered by default.

92 changes: 57 additions & 35 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from pymc.util import RandomState
from pytensor import Variable, graph_replace
from pytensor.compile import get_mode
from rich.box import SIMPLE_HEAD
from rich.console import Console
from rich.table import Table

from pymc_extras.statespace.core.representation import PytensorRepresentation
from pymc_extras.statespace.filters import (
Expand Down Expand Up @@ -254,53 +257,72 @@ def __init__(
self.kalman_smoother = KalmanSmoother()
self.make_symbolic_graph()

if verbose:
# These are split into separate try-except blocks, because it will be quite rare of models to implement
# _print_data_requirements, but we still want to print the prior requirements.
try:
self._print_prior_requirements()
except NotImplementedError:
pass
try:
self._print_data_requirements()
except NotImplementedError:
pass

def _print_prior_requirements(self) -> None:
"""
Prints a short report to the terminal about the priors needed for the model, including their names,
self.requirement_table = None
self._populate_prior_requirements()
self._populate_data_requirements()

if verbose and self.requirement_table:
console = Console()
console.print(self.requirement_table)

def _populate_prior_requirements(self) -> None:
"""
Add requirements about priors needed for the model to a rich table, including their names,
shapes, named dimensions, and any parameter constraints.
"""
out = ""
for param, info in self.param_info.items():
out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
out = out.rstrip()
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
# is not true.
try:
if not isinstance(self.param_info, dict):
return
except NotImplementedError:
return

_log.info(
"The following parameters should be assigned priors inside a PyMC "
f"model block: \n"
f"{out}"
)
if self.requirement_table is None:
self._initialize_requirement_table()

def _print_data_requirements(self) -> None:
for param, info in self.param_info.items():
self.requirement_table.add_row(
param, str(info["shape"]), info["constraints"], str(info["dims"])
)

def _populate_data_requirements(self) -> None:
"""
Prints a short report to the terminal about the data needed for the model, including their names, shapes,
and named dimensions.
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
"""
if not self.data_info:
try:
if not isinstance(self.data_info, dict):
return
except NotImplementedError:
return

out = ""
if self.requirement_table is None:
self._initialize_requirement_table()
else:
self.requirement_table.add_section()

for data, info in self.data_info.items():
out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
out = out.rstrip()
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))

def _initialize_requirement_table(self) -> None:
self.requirement_table = Table(
show_header=True,
show_edge=True,
box=SIMPLE_HEAD,
highlight=True,
)

_log.info(
"The following Data variables should be assigned to the model inside a PyMC "
f"model block: \n"
f"{out}"
self.requirement_table.title = "Model Requirements"
self.requirement_table.caption = (
"These parameters should be assigned priors inside a PyMC model block before "
"calling the build_statespace_graph method."
)

self.requirement_table.add_column("Variable", justify="left")
self.requirement_table.add_column("Shape", justify="left")
self.requirement_table.add_column("Constraints", justify="left")
self.requirement_table.add_column("Dimensions", justify="right")

def _unpack_statespace_with_placeholders(
self,
) -> tuple[
Expand Down
Loading