From a4dacd852b04fccd8b70c0733dc91fde75d661dc Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sun, 2 Nov 2025 10:47:22 -0700 Subject: [PATCH 1/8] proposal for updating propogate_component_properties using data classes --- .../structural_components_dataclass.ipynb | 583 ++++++++++++++++++ .../statespace/models/structural/__init__.py | 4 + .../components/regression_dataclass.py | 539 ++++++++++++++++ .../statespace/models/structural/core.py | 4 +- 4 files changed, 1128 insertions(+), 2 deletions(-) create mode 100644 notebooks/structural_components_dataclass.ipynb create mode 100644 pymc_extras/statespace/models/structural/components/regression_dataclass.py diff --git a/notebooks/structural_components_dataclass.ipynb b/notebooks/structural_components_dataclass.ipynb new file mode 100644 index 000000000..611d76767 --- /dev/null +++ b/notebooks/structural_components_dataclass.ipynb @@ -0,0 +1,583 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ab70a522", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + } + ], + "source": [ + "from pymc_extras.statespace.models.structural import (\n", + " RegressionComponent,\n", + " RegressionComponentDataClass,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "17021aa3", + "metadata": {}, + "outputs": [], + "source": [ + "# Current way\n", + "reg = RegressionComponent(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "219eb5da", + "metadata": {}, + "outputs": [], + "source": [ + "# Proposed way\n", + "reg_dataclass = RegressionComponentDataClass(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7ff76653", + "metadata": {}, + "source": [ + "# Reminder of current implementation" + ] + }, + { + "cell_type": "markdown", + "id": "c05f86f6", + "metadata": {}, + "source": [ + "Currently state names are a list of string that only contain the names of the states" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7e37e574", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a[regression_shared]', 'b[regression_shared]']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.state_names" + ] + }, + { + "cell_type": "markdown", + "id": "0d484b59", + "metadata": {}, + "source": [ + "In the proposed dataclass implementation each state is a `StateProperty` and all the states are `StateProporties` dataclasses." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dee62a66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "states: ['a[regression_shared]', 'b[regression_shared]']\n", + "observed: [True, True]\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cebd72af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a[regression_shared]\n", + "observed: True\n", + "shared: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names[\"a[regression_shared]\"]) # state name is the key" + ] + }, + { + "cell_type": "markdown", + "id": "1b8690a1", + "metadata": {}, + "source": [ + "Similarly with shock names we now have a shock_info that is a `ShockProperties` dataclass composed of `ShockProperty` dataclasses" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1320adac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a_shared', 'b_shared']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.shock_names" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6c905946", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shocks: ['a_shared', 'b_shared']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.shock_info)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff60922", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a_shared\n" + ] + } + ], + "source": [ + "print(reg_dataclass.shock_info[\"a_shared\"])" + ] + }, + { + "cell_type": "markdown", + "id": "bdbe8f7c", + "metadata": {}, + "source": [ + "This pattern continues for data, parameters, and coords as shown below" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ead54287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['data_regression']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_names" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ba784a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data_regression': {'shape': (None, 2), 'dims': ('time', 'state_regression')}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_info" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "521382b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data: ['data_regression']\n", + "needs exogenous data: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "85b7e774", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: data_regression\n", + "shape: (None, 2)\n", + "dims: ('time', 'state_regression')\n", + "is_exogenous: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info[\"data_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e1ed9d7a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': {'shape': (2,),\n", + " 'constraints': None,\n", + " 'dims': ('state_regression',)},\n", + " 'sigma_beta_regression': {'shape': (2,),\n", + " 'constraints': 'Positive',\n", + " 'dims': ('state_regression',)}}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_info" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8d194fe2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['beta_regression', 'sigma_beta_regression']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_names" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7fccad81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': ('state_regression',),\n", + " 'sigma_beta_regression': ('state_regression',)}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_dims" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9787c813", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "914e97da", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "98875fd1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: sigma_beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n", + "constraints: Positive\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"sigma_beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a195cec5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'state_regression': ['a', 'b'], 'endog_regression': ['y']}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.coords" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "62622777", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "coordinates:\n", + " dimension: state_regression\n", + " labels: ['a', 'b']\n", + "\n", + " dimension: endog_regression\n", + " labels: ['y']\n", + "\n" + ] + } + ], + "source": [ + "print(reg_dataclass.coords)" + ] + }, + { + "cell_type": "markdown", + "id": "a79b845c", + "metadata": {}, + "source": [ + "# Mapping between items" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9484c709", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "# Important to be able to map between parameters -> dimensions -> dimension labels\n", + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "85573fa2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"]) # Key is parameter name" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "32f56fd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# dimension for parameter beta_regression is state_regression. Let's map to dimension labels\n", + "print(\n", + " reg_dataclass.coords[\n", + " reg_dataclass.param_info[\"beta_regression\"].dims[0] # Key is dimension name\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "35ae00a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# Equivalently\n", + "print(reg_dataclass.coords[\"state_regression\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/models/structural/__init__.py b/pymc_extras/statespace/models/structural/__init__.py index f0bfb2f0a..8ef35c969 100644 --- a/pymc_extras/statespace/models/structural/__init__.py +++ b/pymc_extras/statespace/models/structural/__init__.py @@ -5,6 +5,9 @@ from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError from pymc_extras.statespace.models.structural.components.regression import RegressionComponent +from pymc_extras.statespace.models.structural.components.regression_dataclass import ( + RegressionComponent as RegressionComponentDataClass, +) from pymc_extras.statespace.models.structural.components.seasonality import ( FrequencySeasonality, TimeSeasonality, @@ -17,5 +20,6 @@ "LevelTrendComponent", "MeasurementError", "RegressionComponent", + "RegressionComponentDataClass", "TimeSeasonality", ] diff --git a/pymc_extras/statespace/models/structural/components/regression_dataclass.py b/pymc_extras/statespace/models/structural/components/regression_dataclass.py new file mode 100644 index 000000000..607b2469b --- /dev/null +++ b/pymc_extras/statespace/models/structural/components/regression_dataclass.py @@ -0,0 +1,539 @@ +from dataclasses import dataclass, field + +import numpy as np + +from pytensor import tensor as pt + +from pymc_extras.statespace.models.structural.core import Component +from pymc_extras.statespace.utils.constants import TIME_DIM + + +@dataclass +class ParameterProperty: + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + constraints: str | None = None + + def __str__(self): + base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}" + if self.constraints: + return base + f"\nconstraints: {self.constraints}" + return base + + +@dataclass +class ParameterProperties: + parameters: list[ParameterProperty] + + def get_parameter(self, name: str) -> ParameterProperty | None: + return next((p for p in self.parameters if p.name == name), None) + + def __getitem__(self, name: str) -> ParameterProperty: + result = next((p for p in self.parameters if p.name == name), None) + if result is None: + raise KeyError(f"No parameter named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(p.name == name for p in self.parameters) + + def __str__(self): + base = f"parameters: {[parameter.name for parameter in self.parameters]}" + return base + + +@dataclass +class DataProperty: + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + is_exogenous: bool + + def __str__(self): + base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}\nis_exogenous: {self.is_exogenous}" + return base + + +@dataclass +class DataProperties: + data: list[DataProperty] + needs_exogenous_data: bool = field(default=False, init=False) + + def __post_init__(self): + for d in self.data: + if d.is_exogenous: + self.needs_exogenous_data = True + + def get_data(self, name: str) -> DataProperty | None: + return next((d for d in self.data if d.name == name), None) + + def __getitem__(self, name: str) -> DataProperty: + result = next((d for d in self.data if d.name == name), None) + if result is None: + raise KeyError(f"No data named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(d.name == name for d in self.data) + + def __str__(self): + base = f"data: {[d.name for d in self.data]}\nneeds exogenous data: {self.needs_exogenous_data}" + return base + + +@dataclass +class CoordProperty: + dimension: str + labels: list[str] + + def __str__(self): + base = f"dimension: {self.dimension}\nlabels: {self.labels}" + return base + + +@dataclass +class CoordProperties: + coords: list[CoordProperty] + + def get_coord(self, dimension: str) -> CoordProperty | None: + return next((c for c in self.coords if c.dimension == dimension), None) + + def __getitem__(self, dimension: str) -> CoordProperty: + result = next((c for c in self.coords if c.dimension == dimension), None) + if result is None: + raise KeyError(f"No coordinate named '{dimension}'") + return result + + def __contains__(self, dimension: str) -> bool: + return any(c.dimension == dimension for c in self.coords) + + def __str__(self): + base = "coordinates:" + for coord in self.coords: + coord_str = str(coord) + indented = "\n".join(" " + line for line in coord_str.splitlines()) + base += "\n" + indented + "\n" + return base + + +@dataclass +class StateProperty: + name: str + observed: bool + shared: bool + + def __str__(self): + base = f"name: {self.name}\nobserved: {self.observed}\nshared: {self.shared}" + return base + + +@dataclass +class StateProperties: + states: list[StateProperty] + + def get_state(self, name: str) -> StateProperty | None: + return next((s for s in self.states if s.name == name), None) + + def __getitem__(self, name: str) -> StateProperty: + result = next((s for s in self.states if s.name == name), None) + if result is None: + raise KeyError(f"No state named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(s.name == name for s in self.states) + + def __str__(self): + base = f"states: {[state.name for state in self.states]}\nobserved: {[state.observed for state in self.states]}" + return base + + +@dataclass +class ShockProperty: + name: str + + def __str__(self): + base = f"name: {self.name}" + return base + + +@dataclass +class ShockProperties: + shocks: list[ShockProperty] + + def get_state(self, name: str) -> ShockProperty | None: + return next((shock for shock in self.shocks if shock.name == name), None) + + def __getitem__(self, name: str) -> ShockProperty: + result = next((shock for shock in self.shocks if shock.name == name), None) + if result is None: + raise KeyError(f"No shock named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(shock.name == name for shock in self.shocks) + + def __str__(self): + base = f"shocks: {[shock.name for shock in self.shocks]}" + return base + + +class RegressionComponent(Component): + r""" + Regression component for exogenous variables in a structural time series model + + Parameters + ---------- + k_exog : int | None, default None + Number of exogenous variables to include in the regression. Must be specified if + state_names is not provided. + + name : str | None, default "regression" + A name for this regression component. Used to label dimensions and coordinates. + + state_names : list[str] | None, default None + List of strings for regression coefficient labels. If provided, must be of length + k_exog. If None and k_exog is provided, coefficients will be named + "{name}_1, {name}_2, ...". + + observed_state_names : list[str] | None, default None + List of strings for observed state labels. If None, defaults to ["data"]. + + innovations : bool, default False + Whether to include stochastic innovations in the regression coefficients, + allowing them to vary over time. If True, coefficients follow a random walk. + + share_states: bool, default False + Whether latent states are shared across the observed states. If True, there will be only one set of latent + states, which are observed by all observed states. If False, each observed state has its own set of + latent states. + + Notes + ----- + This component implements regression with exogenous variables in a structural time series + model. The regression component can be expressed as: + + .. math:: + y_t = \beta_t^T x_t + \epsilon_t + + Where :math:`y_t` is the dependent variable, :math:`x_t` is the vector of exogenous + variables, :math:`\beta_t` is the vector of regression coefficients, and :math:`\epsilon_t` + is the error term. + + When ``innovations=False`` (default), the coefficients are constant over time: + :math:`\beta_t = \beta_0` for all t. + + When ``innovations=True``, the coefficients follow a random walk: + :math:`\beta_{t+1} = \beta_t + \eta_t`, where :math:`\eta_t \sim N(0, \Sigma_\beta)`. + + The component supports both univariate and multivariate regression. In the multivariate + case, separate coefficients are estimated for each endogenous variable (i.e time series). + + Examples + -------- + Simple regression with constant coefficients: + + .. code:: python + + from pymc_extras.statespace import structural as st + import pymc as pm + import pytensor.tensor as pt + + trend = st.LevelTrendComponent(order=1, innovations_order=1) + regression = st.RegressionComponent(k_exog=2, state_names=['intercept', 'slope']) + ss_mod = (trend + regression).build() + + with pm.Model(coords=ss_mod.coords) as model: + # Prior for regression coefficients + betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) + + # Prior for trend innovations + sigma_trend = pm.Exponential('sigma_trend', 1) + + ss_mod.build_statespace_graph(data) + idata = pm.sample() + + Multivariate regression with time-varying coefficients: + - There are 2 exogenous variables (price and income effects) + - There are 2 endogenous variables (sales and revenue) + - The regression coefficients are allowed to vary over time (`innovations=True`) + + .. code:: python + + regression = st.RegressionComponent( + k_exog=2, + state_names=['price_effect', 'income_effect'], + observed_state_names=['sales', 'revenue'], + innovations=True + ) + + with pm.Model(coords=ss_mod.coords) as model: + betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) + + # Innovation variance for time-varying coefficients + sigma_beta = pm.Exponential('sigma_beta', 1, dims=ss_mod.param_dims['sigma_beta_regression']) + + ss_mod.build_statespace_graph(data) + idata = pm.sample() + """ + + def __init__( + self, + k_exog: int | None = None, + name: str | None = "regression", + state_names: list[str] | None = None, + observed_state_names: list[str] | None = None, + innovations=False, + share_states: bool = False, + ): + self.share_states = share_states + + if observed_state_names is None: + observed_state_names = ["data"] + + self.innovations = innovations + k_exog = self._handle_input_data(k_exog, state_names, name) + + k_states = k_exog + k_endog = len(observed_state_names) + k_posdef = k_exog + + super().__init__( + name=name, + k_endog=k_endog, + k_states=k_states * k_endog if not share_states else k_states, + k_posdef=k_posdef * k_endog if not share_states else k_posdef, + state_names=self.state_names, + share_states=share_states, + observed_state_names=observed_state_names, + measurement_error=False, + combine_hidden_states=False, + exog_names=[f"data_{name}"], + obs_state_idxs=np.ones(k_states), + ) + + @staticmethod + def _get_state_names(k_exog: int | None, state_names: list[str] | None, name: str): + if k_exog is None and state_names is None: + raise ValueError("Must specify at least one of k_exog or state_names") + if state_names is not None and k_exog is not None: + if len(state_names) != k_exog: + raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}") + elif k_exog is None: + k_exog = len(state_names) + else: + state_names = [f"{name}_{i + 1}" for i in range(k_exog)] + + return k_exog, state_names + + def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -> int: + k_exog, state_names = self._get_state_names(k_exog, state_names, name) + self.state_names = state_names + + return k_exog + + def make_symbolic_graph(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + + k_states = self.k_states // k_endog_effective + + betas = self.make_and_register_variable( + f"beta_{self.name}", shape=(k_endog, k_states) if k_endog_effective > 1 else (k_states,) + ) + regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states)) + + self.ssm["initial_state", :] = betas.ravel() + self.ssm["transition", :, :] = pt.eye(self.k_states) + self.ssm["selection", :, :] = pt.eye(self.k_states) + + if self.share_states: + self.ssm["design"] = pt.specify_shape( + pt.join(1, *[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]), + (None, k_endog, self.k_states), + ) + else: + Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]) + self.ssm["design"] = pt.specify_shape( + Z, (None, k_endog, regression_data.type.shape[1] * k_endog) + ) + + if self.innovations: + sigma_beta = self.make_and_register_variable( + f"sigma_beta_{self.name}", + (k_states,) if k_endog_effective == 1 else (k_endog, k_states), + ) + row_idx, col_idx = np.diag_indices(self.k_states) + self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2 + + def _set_parameters(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + beta_param_name = f"beta_{self.name}" + beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) + beta_param_dims = ( + (f"endog_{self.name}", f"state_{self.name}") + if k_endog_effective > 1 + else (f"state_{self.name}",) + ) + + beta_param_constraints = None + + if self.innovations: + sigma_param_name = f"sigma_beta_{self.name}" + sigma_param_dims = (f"state_{self.name}",) + sigma_param_shape = (k_states,) + sigma_param_constraints = "Positive" + + beta_parameter = ParameterProperty( + name=beta_param_name, + shape=beta_param_shape, + dims=beta_param_dims, + constraints=beta_param_constraints, + ) + + sigma_parameter = ParameterProperty( + name=sigma_param_name, + shape=sigma_param_shape, + dims=sigma_param_dims, + constraints=sigma_param_constraints, + ) + + self.param_info = ParameterProperties(parameters=[beta_parameter, sigma_parameter]) + + def _set_data(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + data_name = f"data_{self.name}" + data_shape = (None, k_states) + data_dims = (TIME_DIM, f"state_{self.name}") + + data_prop = DataProperty( + name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True + ) + self.data_info = DataProperties(data=[data_prop]) + + def _set_shocks(self) -> None: + if self.share_states: + shock_names = [f"{state_name}_shared" for state_name in self.state_names] + else: + shock_names = self.state_names + + self.shock_info = ShockProperties(shocks=[ShockProperty(name=name) for name in shock_names]) + + def _set_states(self) -> None: + self.base_names = self.state_names + + if self.share_states: + state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] + self.state_names = StateProperties( + states=[ + StateProperty(name=name, observed=True, shared=True) for name in state_names + ] + ) + else: + state_names = [ + f"{name}[{obs_name}]" + for obs_name in self.observed_state_names + for name in self.base_names + ] + self.state_names = StateProperties( + states=[ + StateProperty(name=name, observed=True, shared=False) for name in state_names + ] + ) + + def _set_coords(self) -> None: + regression_state_prop = CoordProperty( + dimension=f"state_{self.name}", labels=[state for state in self.base_names] + ) + endogenous_state_prop = CoordProperty( + dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] + ) + + self.coords = CoordProperties(coords=[regression_state_prop, endogenous_state_prop]) + + def populate_component_properties(self) -> None: + # k_endog_eff, k_states = self._effective_shape_info() + + # 1. Set parameter info + self._set_parameters() + + # 2. Set data info + self._set_data() + + # 3. Set shock info + self._set_shocks() + + # 4. Set states info + self._set_states() + + # 5. Set coordinates info + self._set_coords() + + # def populate_component_properties(self) -> None: + # k_endog = self.k_endog + # k_endog_effective = 1 if self.share_states else k_endog + + # k_states = self.k_states // k_endog_effective + + # if self.share_states: + # self.shock_names = [f"{state_name}_shared" for state_name in self.state_names] + # else: + # self.shock_names = self.state_names + + # self.param_names = [f"beta_{self.name}"] + # self.data_names = [f"data_{self.name}"] + # self.param_dims = { + # f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}") + # if k_endog_effective > 1 + # else (f"state_{self.name}",) + # } + + # base_names = self.state_names + + # if self.share_states: + # self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] + # else: + # self.state_names = [ + # f"{name}[{obs_name}]" + # for obs_name in self.observed_state_names + # for name in base_names + # ] + + # self.param_info = { + # f"beta_{self.name}": { + # "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), + # "constraints": None, + # "dims": (f"endog_{self.name}", f"state_{self.name}") + # if k_endog_effective > 1 + # else (f"state_{self.name}",), + # }, + # } + + # self.data_info = { + # f"data_{self.name}": { + # "shape": (None, k_states), + # "dims": (TIME_DIM, f"state_{self.name}"), + # }, + # } + # self.coords = { + # f"state_{self.name}": base_names, + # f"endog_{self.name}": self.observed_state_names, + # } + + # if self.innovations: + # self.param_names += [f"sigma_beta_{self.name}"] + # self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) + # self.param_info[f"sigma_beta_{self.name}"] = { + # "shape": (k_states,), + # "constraints": "Positive", + # "dims": (f"state_{self.name}",) + # if k_endog_effective == 1 + # else (f"endog_{self.name}", f"state_{self.name}"), + # } diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index a2718251b..621daa313 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -595,7 +595,7 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``param_names`` property. """ - if name not in self.param_names: + if name not in self.param_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." @@ -632,7 +632,7 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``data_names`` property. """ - if name not in self.data_names: + if name not in self.data_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." From 7f32a48055fce5a9bb7951739a9be7bc4d799ffe Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 7 Nov 2025 08:07:27 -0600 Subject: [PATCH 2/8] Iterate on proposal --- pymc_extras/statespace/core/properties.py | 176 ++++++++++++++++++++++ tests/statespace/core/test_properties.py | 119 +++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 pymc_extras/statespace/core/properties.py create mode 100644 tests/statespace/core/test_properties.py diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py new file mode 100644 index 000000000..ecaa1691c --- /dev/null +++ b/pymc_extras/statespace/core/properties.py @@ -0,0 +1,176 @@ +from collections.abc import Iterator +from dataclasses import dataclass, fields +from typing import Generic, Self, TypeVar + +from pymc_extras.statespace.core import PyMCStateSpace +from pymc_extras.statespace.models.structural.core import Component +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +@dataclass(frozen=True) +class Property: + def __str__(self) -> str: + return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self)) + + +T = TypeVar("T", bound=Property) + + +@dataclass(frozen=True) +class Info(Generic[T]): + items: tuple[T, ...] + key_field: str = "name" + _index: dict[str, T] | None = None + + def __post_init__(self): + index = {} + missing_attr = [] + for item in self.items: + if not hasattr(item, self.key_field): + missing_attr.append(item) + continue + key = getattr(item, self.key_field) + if key in index: + raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") + index[key] = item + if missing_attr: + raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}") + object.__setattr__(self, "_index", index) + + def _key(self, item: T) -> str: + return getattr(item, self.key_field) + + def get(self, key: str, default=None) -> T | None: + return self._index.get(key, default) + + def __getitem__(self, key: str) -> T: + try: + return self._index[key] + except KeyError as e: + available = ", ".join(self._index.keys()) + raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e + + def __contains__(self, key: object) -> bool: + return key in self._index + + def __iter__(self) -> Iterator[str]: + return iter(self._index) + + def __len__(self) -> int: + return len(self._index) + + def __str__(self) -> str: + return f"{self.key_field}s: {list(self._index.keys())}" + + @property + def names(self) -> tuple[str, ...]: + return tuple(self._index.keys()) + + +@dataclass(frozen=True) +class Parameter(Property): + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + constraints: str | None = None + + +@dataclass(frozen=True) +class ParameterInfo(Info[Parameter]): + def __init__(self, parameters: list[Parameter]): + super().__init__(items=tuple(parameters), key_field="name") + + +@dataclass(frozen=True) +class Data(Property): + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + is_exogenous: bool + + +@dataclass(frozen=True) +class DataInfo(Info[Data]): + def __init__(self, data: list[Data]): + super().__init__(items=tuple(data), key_field="name") + + @property + def needs_exogenous_data(self) -> bool: + return any(d.is_exogenous for d in self.items) + + def __str__(self) -> str: + return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}" + + +@dataclass(frozen=True) +class Coord(Property): + dimension: str + labels: tuple[str, ...] + + +@dataclass(frozen=True) +class CoordInfo(Info[Coord]): + def __init__(self, coords: list[Coord]): + super().__init__(items=tuple(coords), key_field="dimension") + + def __str__(self) -> str: + base = "coordinates:" + for coord in self.items: + coord_str = str(coord) + indented = "\n".join(" " + line for line in coord_str.splitlines()) + base += "\n" + indented + "\n" + return base + + @classmethod + def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self: + states = tuple(model.state_names) + obs_states = tuple(model.observed_state_names) + shocks = tuple(model.shock_names) + + dim_to_labels = ( + (ALL_STATE_DIM, states), + (ALL_STATE_AUX_DIM, states), + (OBS_STATE_DIM, obs_states), + (OBS_STATE_AUX_DIM, obs_states), + (SHOCK_DIM, shocks), + (SHOCK_AUX_DIM, shocks), + ) + + coords = [Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels] + return cls(coords) + + +@dataclass(frozen=True) +class State(Property): + name: str + observed: bool + shared: bool + + +@dataclass(frozen=True) +class StateInfo(Info[State]): + def __init__(self, states: list[State]): + super().__init__(items=tuple(states), key_field="name") + + def __str__(self) -> str: + return ( + f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}" + ) + + +@dataclass(frozen=True) +class Shock(Property): + name: str + + +@dataclass(frozen=True) +class ShockInfo(Info[Shock]): + def __init__(self, shocks: list[Shock]): + super().__init__(items=tuple(shocks), key_field="name") diff --git a/tests/statespace/core/test_properties.py b/tests/statespace/core/test_properties.py new file mode 100644 index 000000000..7f7cb8ae3 --- /dev/null +++ b/tests/statespace/core/test_properties.py @@ -0,0 +1,119 @@ +import pytest + +from pymc_extras.statespace.core.properties import ( + CoordInfo, + Data, + DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +def test_property_str_formats_fields(): + p = Parameter(name="alpha", shape=(2,), dims=("param",)) + s = str(p).splitlines() + assert s == [ + "name: alpha", + "shape: (2,)", + "dims: ('param',)", + "constraints: None", + ] + + +def test_info_lookup_contains_and_missing_key(): + params = [ + Parameter("a", (1,), ("d",)), + Parameter("b", (2,), ("d",)), + Parameter("c", (3,), ("d",)), + ] + info = ParameterInfo(params) + + assert info.get("b").name == "b" + assert info["a"].shape == (1,) + assert "c" in info + + with pytest.raises(KeyError) as e: + _ = info["missing"] + assert "No name 'missing'" in str(e.value) + + +def test_data_info_needs_exogenous_and_str(): + data = [ + Data("price", (10,), ("time",), is_exogenous=False), + Data("x", (10,), ("time",), is_exogenous=True), + ] + info = DataInfo(data) + + assert info.needs_exogenous_data is True + s = str(info) + assert "data: ['price', 'x']" in s + assert "needs exogenous data: True" in s + + no_exog = DataInfo([Data("y", (10,), ("time",), is_exogenous=False)]) + assert no_exog.needs_exogenous_data is False + + +def test_coord_info_make_defaults_from_component_and_types(): + class DummyComponent: + state_names = ["x1", "x2"] + observed_state_names = ["x2"] + shock_names = ["eps1"] + + ci = CoordInfo.default_coords_from_model(DummyComponent()) + + expected = [ + (ALL_STATE_DIM, ("x1", "x2")), + (ALL_STATE_AUX_DIM, ("x1", "x2")), + (OBS_STATE_DIM, ("x2",)), + (OBS_STATE_AUX_DIM, ("x2",)), + (SHOCK_DIM, ("eps1",)), + (SHOCK_AUX_DIM, ("eps1",)), + ] + + assert len(ci.items) == 6 + for dim, labels in expected: + assert dim in ci + assert ci[dim].labels == labels + assert isinstance(ci[dim].labels, tuple) + + +def test_state_info_and_shockinfo_basic(): + states = [ + State("x1", observed=True, shared=False), + State("x2", observed=False, shared=True), + ] + state_info = StateInfo(states) + assert state_info["x1"].observed is True + s = str(state_info) + + assert "states: ['x1', 'x2']" in s + assert "observed: [True, False]" in s + + shocks = [Shock("s1"), Shock("s2")] + shock_info = ShockInfo(shocks) + + assert "s1" in shock_info + assert shock_info["s2"].name == "s2" + + +def test_info_is_iterable_and_unpackable(): + items = [Parameter("p1", (1,), ("d",)), Parameter("p2", (2,), ("d",))] + info = ParameterInfo(items) + + names = info.names + assert names == ("p1", "p2") + + a, b = info.items + assert a.name == "p1" and b.name == "p2" From d65fc0aa165f45e134b38f37c635b8a7ebfd1e51 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 7 Nov 2025 11:45:17 -0600 Subject: [PATCH 3/8] Fix iterator, add `to_dict` method to `CoordsInfo` --- pymc_extras/statespace/core/properties.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py index ecaa1691c..c15253450 100644 --- a/pymc_extras/statespace/core/properties.py +++ b/pymc_extras/statespace/core/properties.py @@ -61,10 +61,10 @@ def __contains__(self, key: object) -> bool: return key in self._index def __iter__(self) -> Iterator[str]: - return iter(self._index) + return iter(self.items) def __len__(self) -> int: - return len(self._index) + return len(self.items) def __str__(self) -> str: return f"{self.key_field}s: {list(self._index.keys())}" @@ -146,6 +146,9 @@ def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self: coords = [Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels] return cls(coords) + def to_dict(self): + return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0} + @dataclass(frozen=True) class State(Property): From c6a48fcbb62b7fce3b2fcd51f129407c3699eb1d Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 7 Nov 2025 11:46:05 -0600 Subject: [PATCH 4/8] Add `observed_states` helper to `StateInfo` --- pymc_extras/statespace/core/properties.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py index c15253450..d133ccc4c 100644 --- a/pymc_extras/statespace/core/properties.py +++ b/pymc_extras/statespace/core/properties.py @@ -167,6 +167,10 @@ def __str__(self) -> str: f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}" ) + @property + def observed_states(self) -> tuple[State, ...]: + return tuple(s for s in self.items if s.observed) + @dataclass(frozen=True) class Shock(Property): From 92e333fd57d7082bf45d65674311ad71a5325184 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sun, 9 Nov 2025 09:15:54 -0700 Subject: [PATCH 5/8] made necessary changes to get the regression component test to pass using the new dataclasses API --- pymc_extras/statespace/core/properties.py | 82 +++++++++- .../structural/components/regression.py | 146 ++++++++++++------ .../statespace/models/structural/core.py | 96 +++++++++--- pymc_extras/statespace/models/utilities.py | 26 +++- .../structural/components/test_regression.py | 24 ++- .../statespace/models/structural/conftest.py | 12 +- 6 files changed, 293 insertions(+), 93 deletions(-) diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py index d133ccc4c..56e875f94 100644 --- a/pymc_extras/statespace/core/properties.py +++ b/pymc_extras/statespace/core/properties.py @@ -3,7 +3,6 @@ from typing import Generic, Self, TypeVar from pymc_extras.statespace.core import PyMCStateSpace -from pymc_extras.statespace.models.structural.core import Component from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, @@ -87,6 +86,21 @@ class ParameterInfo(Info[Parameter]): def __init__(self, parameters: list[Parameter]): super().__init__(items=tuple(parameters), key_field="name") + def add(self, parameter: Parameter) -> "ParameterInfo": + # return a new ParameterInfo with parameter appended + return ParameterInfo(parameters=[*list(self.items), parameter]) + + def merge(self, other: "ParameterInfo") -> "ParameterInfo": + """Combine parameters from two ParameterInfo objects.""" + if not isinstance(other, ParameterInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate parameter names found: {overlapping}") + + return ParameterInfo(parameters=list(self.items) + list(other.items)) + @dataclass(frozen=True) class Data(Property): @@ -108,6 +122,21 @@ def needs_exogenous_data(self) -> bool: def __str__(self) -> str: return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}" + def add(self, data: Data) -> "DataInfo": + # return a new DataInfo with data appended + return DataInfo(data=[*list(self.items), data]) + + def merge(self, other: "DataInfo") -> "DataInfo": + """Combine data from two DataInfo objects.""" + if not isinstance(other, DataInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate data names found: {overlapping}") + + return DataInfo(data=list(self.items) + list(other.items)) + @dataclass(frozen=True) class Coord(Property): @@ -129,7 +158,11 @@ def __str__(self) -> str: return base @classmethod - def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self: + def default_coords_from_model( + cls, model: PyMCStateSpace + ) -> ( + Self + ): # TODO: Need to figure out how to include Component type was causing circular import issues states = tuple(model.state_names) obs_states = tuple(model.observed_state_names) shocks = tuple(model.shock_names) @@ -149,6 +182,21 @@ def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self: def to_dict(self): return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0} + def add(self, coord: Coord) -> "CoordInfo": + # return a new CoordInfo with data appended + return CoordInfo(coords=[*list(self.items), coord]) + + def merge(self, other: "CoordInfo") -> "CoordInfo": + """Combine data from two CoordInfo objects.""" + if not isinstance(other, CoordInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate coord names found: {overlapping}") + + return CoordInfo(coords=list(self.items) + list(other.items)) + @dataclass(frozen=True) class State(Property): @@ -171,6 +219,21 @@ def __str__(self) -> str: def observed_states(self) -> tuple[State, ...]: return tuple(s for s in self.items if s.observed) + def add(self, state: State) -> "StateInfo": + # return a new StateInfo with state appended + return StateInfo(states=[*list(self.items), state]) + + def merge(self, other: "StateInfo") -> "StateInfo": + """Combine states from two StateInfo objects.""" + if not isinstance(other, StateInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate state names found: {overlapping}") + + return StateInfo(states=list(self.items) + list(other.items)) + @dataclass(frozen=True) class Shock(Property): @@ -181,3 +244,18 @@ class Shock(Property): class ShockInfo(Info[Shock]): def __init__(self, shocks: list[Shock]): super().__init__(items=tuple(shocks), key_field="name") + + def add(self, shock: Shock) -> "ShockInfo": + # return a new ShockInfo with shock appended + return ShockInfo(shocks=[*list(self.items), shock]) + + def merge(self, other: "ShockInfo") -> "ShockInfo": + """Combine shocks from two ShockInfo objects.""" + if not isinstance(other, ShockInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate shock names found: {overlapping}") + + return ShockInfo(shocks=list(self.items) + list(other.items)) diff --git a/pymc_extras/statespace/models/structural/components/regression.py b/pymc_extras/statespace/models/structural/components/regression.py index 5620b1ea7..1444b902b 100644 --- a/pymc_extras/statespace/models/structural/components/regression.py +++ b/pymc_extras/statespace/models/structural/components/regression.py @@ -2,6 +2,18 @@ from pytensor import tensor as pt +from pymc_extras.statespace.core.properties import ( + Coord, + CoordInfo, + Data, + DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) from pymc_extras.statespace.models.structural.core import Component from pymc_extras.statespace.utils.constants import TIME_DIM @@ -194,64 +206,110 @@ def make_symbolic_graph(self) -> None: row_idx, col_idx = np.diag_indices(self.k_states) self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2 - def populate_component_properties(self) -> None: + def _set_parameters(self) -> None: k_endog = self.k_endog k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + beta_param_name = f"beta_{self.name}" + beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) + beta_param_dims = ( + (f"endog_{self.name}", f"state_{self.name}") + if k_endog_effective > 1 + else (f"state_{self.name}",) + ) + + beta_param_constraints = None + beta_parameter = Parameter( + name=beta_param_name, + shape=beta_param_shape, + dims=beta_param_dims, + constraints=beta_param_constraints, + ) + if self.innovations: + sigma_param_name = f"sigma_beta_{self.name}" + sigma_param_dims = (f"state_{self.name}",) + sigma_param_shape = (k_states,) + sigma_param_constraints = "Positive" + + sigma_parameter = Parameter( + name=sigma_param_name, + shape=sigma_param_shape, + dims=sigma_param_dims, + constraints=sigma_param_constraints, + ) + + self.param_info = ParameterInfo(parameters=[beta_parameter, sigma_parameter]) + self.param_names = self.param_info.names + else: + self.param_info = ParameterInfo(parameters=[beta_parameter]) + self.param_names = self.param_info.names + + def _set_data(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog k_states = self.k_states // k_endog_effective + data_name = f"data_{self.name}" + data_shape = (None, k_states) + data_dims = (TIME_DIM, f"state_{self.name}") + + data_prop = Data(name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True) + self.data_info = DataInfo(data=[data_prop]) + self.data_names = self.data_info.names + + def _set_shocks(self) -> None: if self.share_states: - self.shock_names = [f"{state_name}_shared" for state_name in self.state_names] + shock_names = [f"{state_name}_shared" for state_name in self.state_names] else: - self.shock_names = self.state_names + shock_names = self.state_names - self.param_names = [f"beta_{self.name}"] - self.data_names = [f"data_{self.name}"] - self.param_dims = { - f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",) - } + self.shock_info = ShockInfo(shocks=[Shock(name=name) for name in shock_names]) + self.shock_names = self.shock_info.names - base_names = self.state_names + def _set_states(self) -> None: + self.base_names = self.state_names if self.share_states: - self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] + state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] + self.state_info = StateInfo( + states=[State(name=name, observed=True, shared=True) for name in state_names] + ) + self.state_names = self.state_info.names else: - self.state_names = [ + state_names = [ f"{name}[{obs_name}]" for obs_name in self.observed_state_names - for name in base_names + for name in self.base_names ] + self.state_info = StateInfo( + states=[State(name=name, observed=True, shared=False) for name in state_names] + ) + self.state_names = self.state_info.names - self.param_info = { - f"beta_{self.name}": { - "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), - "constraints": None, - "dims": (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",), - }, - } - - self.data_info = { - f"data_{self.name}": { - "shape": (None, k_states), - "dims": (TIME_DIM, f"state_{self.name}"), - }, - } - self.coords = { - f"state_{self.name}": base_names, - f"endog_{self.name}": self.observed_state_names, - } + def _set_coords(self) -> None: + regression_state_coord = Coord( + dimension=f"state_{self.name}", labels=[state for state in self.base_names] + ) + endogenous_state_coord = Coord( + dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] + ) - if self.innovations: - self.param_names += [f"sigma_beta_{self.name}"] - self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) - self.param_info[f"sigma_beta_{self.name}"] = { - "shape": (k_states,), - "constraints": "Positive", - "dims": (f"state_{self.name}",) - if k_endog_effective == 1 - else (f"endog_{self.name}", f"state_{self.name}"), - } + self.coords = CoordInfo(coords=[regression_state_coord, endogenous_state_coord]) + + def populate_component_properties(self) -> None: + # Set parameter info + self._set_parameters() + + # Set data info + self._set_data() + + # Set shock info + self._set_shocks() + + # Set states info + self._set_states() + + # Set coordinates info + self._set_coords() diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 621daa313..7b159e7f7 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -2,6 +2,7 @@ import logging from collections.abc import Sequence +from dataclasses import is_dataclass from itertools import pairwise from typing import Any @@ -12,6 +13,10 @@ from pytensor import tensor as pt from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation +from pymc_extras.statespace.core.properties import ( + Parameter, + ParameterInfo, +) from pymc_extras.statespace.models.utilities import ( add_tensors_by_dim_labels, conform_time_varying_and_time_invariant_matrices, @@ -136,6 +141,7 @@ class StructuralTimeSeries(PyMCStateSpace): methods (2nd ed.). Oxford University Press. """ + # TODO need to discuss cutting some of these args down. All the _names args are already inside of _info def __init__( self, ssm: PytensorRepresentation, @@ -150,6 +156,8 @@ def __init__( coords: dict[str, Sequence], param_info: dict[str, dict[str, Any]], data_info: dict[str, dict[str, Any]], + shock_info: dict[str, dict[str, Any]], + state_info: dict[str, dict[str, Any]], component_info: dict[str, dict[str, Any]], measurement_error: bool, name_to_variable: dict[str, Variable], @@ -165,7 +173,7 @@ def __init__( k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog param_names, param_dims, param_info = self._add_inital_state_cov_to_properties( - param_names, param_dims, param_info, k_states + param_info, k_states ) self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog) @@ -175,13 +183,13 @@ def __init__( self._param_dims = param_dims default_coords = make_default_coords(self) - coords.update(default_coords) + coords = coords.merge(default_coords) - self._coords = { - k: self._strip_data_names_if_unambiguous(v, k_endog) for k, v in coords.items() - } - self._param_info = param_info.copy() - self._data_info = data_info.copy() + self._coord_info = coords + self._param_info = param_info # .copy() #TODO add __copy__ to base class + self._data_info = data_info # .copy() + self._shock_info = shock_info + self._state_info = state_info self.measurement_error = measurement_error super().__init__( @@ -236,16 +244,25 @@ def _strip_data_names_if_unambiguous(self, names: list[str], k_endog: int): return names @staticmethod - def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states): - param_names += ["P0"] - param_dims["P0"] = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) - param_info["P0"] = { - "shape": (k_states, k_states), - "constraints": "Positive semi-definite", - "dims": param_dims["P0"], - } + def _add_inital_state_cov_to_properties(param_info, k_states): + initial_state_cov_name = "P0" + initial_state_cov_shape = (k_states, k_states) + initial_state_cov_dims = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) + initial_state_cov_constraints = "Positive semi-definite" + + initial_state_cov_param = Parameter( + name=initial_state_cov_name, + shape=initial_state_cov_shape, + dims=initial_state_cov_dims, + constraints=initial_state_cov_constraints, + ) + + if is_dataclass(param_info): + param_info = param_info.add(initial_state_cov_param) + else: + param_info = ParameterInfo(parameters=[initial_state_cov_param]) - return param_names, param_dims, param_info + return param_info.names, [p.dims for p in param_info], param_info @property def param_names(self): @@ -271,9 +288,9 @@ def shock_names(self): def param_dims(self): return self._param_dims - @property + @property # TODO discuss naming convention _info and need to clean up type hints def coords(self) -> dict[str, Sequence]: - return self._coords + return self._coord_info @property def param_info(self) -> dict[str, dict[str, Any]]: @@ -283,6 +300,14 @@ def param_info(self) -> dict[str, dict[str, Any]]: def data_info(self) -> dict[str, dict[str, Any]]: return self._data_info + @property + def state_info(self) -> dict[str, dict[str, Any]]: + return self._state_info + + @property + def shock_info(self) -> dict[str, dict[str, Any]]: + return self._shock_info + def make_symbolic_graph(self) -> None: """ Assign placeholder pytensor variables among statespace matrices in positions where PyMC variables will go. @@ -540,6 +565,8 @@ def __init__( self.param_info = {} self.data_info = {} + self.shock_info = {} + self.state_info = {} self.param_counts = {} @@ -648,6 +675,21 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: self._name_to_data[name] = placeholder return placeholder + def _set_parameters(self) -> None: + raise NotImplementedError + + def _set_data(self) -> None: + raise NotImplementedError + + def _set_shocks(self) -> None: + raise NotImplementedError + + def _set_states(self) -> None: + raise NotImplementedError + + def _set_coords(self) -> None: + raise NotImplementedError + def make_symbolic_graph(self) -> None: raise NotImplementedError @@ -764,15 +806,17 @@ def _combine_property(self, other, name, allow_duplicates=True): self_prop = getattr(self, name) other_prop = getattr(other, name) + # TODO discuss limiting the types we get here to only a dataclass type. By making the dataclasses immutable we now have to handle for tuples too. + if not isinstance(self_prop, type(other_prop)): raise TypeError( f"Property {name} of {self} and {other} are not the same and cannot be combined. Found " f"{type(self_prop)} for {self} and {type(other_prop)} for {other}'" ) - if not isinstance(self_prop, list | dict): + if not is_dataclass(self_prop) and not isinstance(self_prop, list | tuple | dict): raise TypeError( - f"All component properties are expected to be lists or dicts, but found {type(self_prop)}" + f"All component properties are expected to be dataclasses, but found {type(self_prop)}" f"for property {name} of {self} and {type(other_prop)} for {other}'" ) @@ -784,6 +828,12 @@ def _combine_property(self, other, name, allow_duplicates=True): new_prop = self_prop.copy() new_prop.update(other_prop) return new_prop + # TODO need to handle allow_duplicates but want to wait for above discussion first to see if we can cut down to just dataclass types + elif isinstance(self_prop, tuple): + new_prop = self_prop + other_prop + return new_prop + elif is_dataclass(self_prop): + return self_prop.merge(other_prop) def _combine_component_info(self, other): combined_info = {} @@ -817,6 +867,8 @@ def __add__(self, other): shock_names = self._combine_property(other, "shock_names") param_info = self._combine_property(other, "param_info") data_info = self._combine_property(other, "data_info") + shock_info = self._combine_property(other, "shock_info") + state_info = self._combine_property(other, "state_info") param_dims = self._combine_property(other, "param_dims") coords = self._combine_property(other, "coords") exog_names = self._combine_property(other, "exog_names") @@ -854,6 +906,8 @@ def __add__(self, other): ("param_dims", param_dims), ("param_info", param_info), ("data_info", data_info), + ("shock_info", shock_info), + ("state_info", state_info), ("exog_names", exog_names), ("_name_to_variable", _name_to_variable), ("_name_to_data", _name_to_data), @@ -908,6 +962,8 @@ def build( coords=self.coords, param_info=self.param_info, data_info=self.data_info, + shock_info=self.shock_info, + state_info=self.state_info, component_info=self._component_info, measurement_error=self.measurement_error, exog_names=self.exog_names, diff --git a/pymc_extras/statespace/models/utilities.py b/pymc_extras/statespace/models/utilities.py index 33be8d47d..cab8f3b3c 100644 --- a/pymc_extras/statespace/models/utilities.py +++ b/pymc_extras/statespace/models/utilities.py @@ -5,6 +5,7 @@ from pytensor.tensor import TensorVariable +from pymc_extras.statespace.core.properties import Coord, CoordInfo from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, @@ -19,14 +20,23 @@ def make_default_coords(ss_mod): - coords = { - ALL_STATE_DIM: ss_mod.state_names, - ALL_STATE_AUX_DIM: ss_mod.state_names, - OBS_STATE_DIM: ss_mod.observed_states, - OBS_STATE_AUX_DIM: ss_mod.observed_states, - SHOCK_DIM: ss_mod.shock_names, - SHOCK_AUX_DIM: ss_mod.shock_names, - } + ALL_STATE_COORD = Coord(dimension=ALL_STATE_DIM, labels=ss_mod.state_names) + ALL_STATE_AUX_COORD = Coord(dimension=ALL_STATE_AUX_DIM, labels=ss_mod.state_names) + OBS_STATE_COORD = Coord(dimension=OBS_STATE_DIM, labels=ss_mod.observed_states) + OBS_STATE_AUX_COORD = Coord(dimension=OBS_STATE_AUX_DIM, labels=ss_mod.observed_states) + SHOCK_COORD = Coord(dimension=SHOCK_DIM, labels=ss_mod.shock_names) + SHOCK_AUX_COORD = Coord(dimension=SHOCK_AUX_DIM, labels=ss_mod.shock_names) + + coords = CoordInfo( + coords=[ + ALL_STATE_COORD, + ALL_STATE_AUX_COORD, + OBS_STATE_COORD, + OBS_STATE_AUX_COORD, + SHOCK_COORD, + SHOCK_AUX_COORD, + ] + ) return coords diff --git a/tests/statespace/models/structural/components/test_regression.py b/tests/statespace/models/structural/components/test_regression.py index c1732997d..ffde3348f 100644 --- a/tests/statespace/models/structural/components/test_regression.py +++ b/tests/statespace/models/structural/components/test_regression.py @@ -66,7 +66,7 @@ def test_exogenous_component(self, rng, regression_data, innovations): mod = mod.build(verbose=False) _assert_basic_coords_correct(mod) - assert mod.coords["state_exog"] == ["feature_1", "feature_2"] + assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] if innovations: # Check that sigma_beta parameter is included @@ -125,7 +125,7 @@ def test_regression_with_multiple_observed_states(self, rng, regression_data, in assert_allclose(x[0, 2:], params["beta_exog"][1], atol=ATOL, rtol=RTOL) mod = mod.build(verbose=False) - assert mod.coords["state_exog"] == ["feature_1", "feature_2"] + assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] Z = mod.ssm["design"].eval({"data_exog": regression_data}) vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)") @@ -164,8 +164,8 @@ def test_add_regression_components_with_multiple_observed_states( ) mod = (reg1 + reg2).build(verbose=False) - assert mod.coords["state_exog1"] == ["a", "b"] - assert mod.coords["state_exog2"] == ["c"] + assert mod.coords["state_exog1"].labels == ["a", "b"] + assert mod.coords["state_exog2"].labels == ["c"] Z = mod.ssm["design"].eval( { @@ -211,7 +211,7 @@ def test_filter_scans_time_varying_design_matrix(self, rng, time_series_data, in reg = st.RegressionComponent(state_names=["a", "b"], name="exog", innovations=innovations) mod = reg.build(verbose=False) - with pm.Model(coords=mod.coords) as m: + with pm.Model(coords=mod.coords.to_dict()) as m: data_exog = pm.Data("data_exog", data.values) x0 = pm.Normal("x0", dims=["state"]) @@ -249,14 +249,12 @@ def test_regression_multiple_shared_construction(): assert mod.k_states == 1 assert mod.k_posdef == 1 - assert mod.coords["state_regression"] == ["A"] - assert mod.coords["endog_regression"] == ["data_1", "data_2"] + assert mod.coords["state_regression"].labels == ["A"] + assert mod.coords["endog_regression"].labels == ["data_1", "data_2"] - assert mod.state_names == [ - "A[regression_shared]", - ] + assert mod.state_names == ("A[regression_shared]",) - assert mod.shock_names == ["A_shared"] + assert mod.shock_names == ("A_shared",) data = np.random.standard_normal(size=(10, 1)) Z = mod.ssm["design"].eval({"data_regression": data}) @@ -312,8 +310,8 @@ def test_regression_mixed_shared_and_not_shared(): assert mod.k_states == 4 assert mod.k_posdef == 4 - assert mod.state_names == ["A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]"] - assert mod.shock_names == ["A", "B_shared", "C_shared"] + assert mod.state_names == ("A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]") + assert mod.shock_names == ("A", "B_shared", "C_shared") data_joint = np.random.standard_normal(size=(10, 2)) data_individual = np.random.standard_normal(size=(10, 1)) diff --git a/tests/statespace/models/structural/conftest.py b/tests/statespace/models/structural/conftest.py index b9e58ca68..15dac710d 100644 --- a/tests/statespace/models/structural/conftest.py +++ b/tests/statespace/models/structural/conftest.py @@ -19,11 +19,11 @@ def rng(): def _assert_basic_coords_correct(mod): - assert mod.coords[ALL_STATE_DIM] == mod.state_names - assert mod.coords[ALL_STATE_AUX_DIM] == mod.state_names - assert mod.coords[SHOCK_DIM] == mod.shock_names - assert mod.coords[SHOCK_AUX_DIM] == mod.shock_names + assert mod.coords[ALL_STATE_DIM].labels == mod.state_names + assert mod.coords[ALL_STATE_AUX_DIM].labels == mod.state_names + assert mod.coords[SHOCK_DIM].labels == mod.shock_names + assert mod.coords[SHOCK_AUX_DIM].labels == mod.shock_names expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ["data"] - assert mod.coords[OBS_STATE_DIM] == expected_obs - assert mod.coords[OBS_STATE_AUX_DIM] == expected_obs + assert mod.coords[OBS_STATE_DIM].labels == expected_obs + assert mod.coords[OBS_STATE_AUX_DIM].labels == expected_obs From a183c717fc961f4562251fe5d8abc8b4591ba7dd Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 15 Nov 2025 08:56:12 -0700 Subject: [PATCH 6/8] 1. Updated dataclasses to include copy method and replaced raise on duplicate with warning 2. removed unnecessary imports from __init__ after deleting regression_dataclass 3. updated components and structural classes to only utilize dataclasses and pull other objects from _info dataclasses 4. updated tests to conform to dataclass api --- pymc_extras/statespace/core/properties.py | 56 +- .../statespace/models/structural/__init__.py | 4 - .../structural/components/regression.py | 80 ++- .../components/regression_dataclass.py | 539 ------------------ .../statespace/models/structural/core.py | 214 ++++--- .../structural/components/test_regression.py | 15 +- .../statespace/models/structural/conftest.py | 14 +- 7 files changed, 191 insertions(+), 731 deletions(-) delete mode 100644 pymc_extras/statespace/models/structural/components/regression_dataclass.py diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py index 56e875f94..35c1394d0 100644 --- a/pymc_extras/statespace/core/properties.py +++ b/pymc_extras/statespace/core/properties.py @@ -1,4 +1,7 @@ +import warnings + from collections.abc import Iterator +from copy import deepcopy from dataclasses import dataclass, fields from typing import Generic, Self, TypeVar @@ -36,8 +39,8 @@ def __post_init__(self): missing_attr.append(item) continue key = getattr(item, self.key_field) - if key in index: - raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") + # if key in index: + # raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states index[key] = item if missing_attr: raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}") @@ -72,6 +75,9 @@ def __str__(self) -> str: def names(self) -> tuple[str, ...]: return tuple(self._index.keys()) + def copy(self) -> "Info[T]": + return deepcopy(self) + @dataclass(frozen=True) class Parameter(Property): @@ -90,13 +96,13 @@ def add(self, parameter: Parameter) -> "ParameterInfo": # return a new ParameterInfo with parameter appended return ParameterInfo(parameters=[*list(self.items), parameter]) - def merge(self, other: "ParameterInfo") -> "ParameterInfo": + def merge(self, other: "ParameterInfo", allow_duplicates: bool = False) -> "ParameterInfo": """Combine parameters from two ParameterInfo objects.""" if not isinstance(other, ParameterInfo): raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo") overlapping = set(self.names) & set(other.names) - if overlapping: + if overlapping and not allow_duplicates: raise ValueError(f"Duplicate parameter names found: {overlapping}") return ParameterInfo(parameters=list(self.items) + list(other.items)) @@ -119,6 +125,10 @@ def __init__(self, data: list[Data]): def needs_exogenous_data(self) -> bool: return any(d.is_exogenous for d in self.items) + @property + def exogenous_names(self) -> tuple[str, ...]: + return tuple(d.name for d in self.items if d.is_exogenous) + def __str__(self) -> str: return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}" @@ -126,13 +136,13 @@ def add(self, data: Data) -> "DataInfo": # return a new DataInfo with data appended return DataInfo(data=[*list(self.items), data]) - def merge(self, other: "DataInfo") -> "DataInfo": + def merge(self, other: "DataInfo", allow_duplicates: bool = False) -> "DataInfo": """Combine data from two DataInfo objects.""" if not isinstance(other, DataInfo): raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo") overlapping = set(self.names) & set(other.names) - if overlapping: + if overlapping and not allow_duplicates: raise ValueError(f"Duplicate data names found: {overlapping}") return DataInfo(data=list(self.items) + list(other.items)) @@ -164,7 +174,7 @@ def default_coords_from_model( Self ): # TODO: Need to figure out how to include Component type was causing circular import issues states = tuple(model.state_names) - obs_states = tuple(model.observed_state_names) + obs_states = tuple(model.observed_states) shocks = tuple(model.shock_names) dim_to_labels = ( @@ -186,13 +196,13 @@ def add(self, coord: Coord) -> "CoordInfo": # return a new CoordInfo with data appended return CoordInfo(coords=[*list(self.items), coord]) - def merge(self, other: "CoordInfo") -> "CoordInfo": + def merge(self, other: "CoordInfo", allow_duplicates: bool = False) -> "CoordInfo": """Combine data from two CoordInfo objects.""" if not isinstance(other, CoordInfo): raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo") overlapping = set(self.names) & set(other.names) - if overlapping: + if overlapping and not allow_duplicates: raise ValueError(f"Duplicate coord names found: {overlapping}") return CoordInfo(coords=list(self.items) + list(other.items)) @@ -216,21 +226,37 @@ def __str__(self) -> str: ) @property - def observed_states(self) -> tuple[State, ...]: + def observed_states(self) -> tuple[State, ...]: # Is this needed?? return tuple(s for s in self.items if s.observed) + @property + def observed_state_names(self) -> tuple[State, ...]: + return tuple(s.name for s in self.items if s.observed) + + @property + def unobserved_state_names(self) -> tuple[State, ...]: + return tuple(s.name for s in self.items if not s.observed) + def add(self, state: State) -> "StateInfo": # return a new StateInfo with state appended return StateInfo(states=[*list(self.items), state]) - def merge(self, other: "StateInfo") -> "StateInfo": + def merge(self, other: "StateInfo", allow_duplicates: bool = False) -> "StateInfo": """Combine states from two StateInfo objects.""" if not isinstance(other, StateInfo): raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo") overlapping = set(self.names) & set(other.names) - if overlapping: - raise ValueError(f"Duplicate state names found: {overlapping}") + if overlapping and not allow_duplicates: + # This is necessary for shared states + warnings.warn( + f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states", + UserWarning, + ) + return StateInfo( + states=list(self.items) + + [item for item in other.items if item.name not in overlapping] + ) return StateInfo(states=list(self.items) + list(other.items)) @@ -249,13 +275,13 @@ def add(self, shock: Shock) -> "ShockInfo": # return a new ShockInfo with shock appended return ShockInfo(shocks=[*list(self.items), shock]) - def merge(self, other: "ShockInfo") -> "ShockInfo": + def merge(self, other: "ShockInfo", allow_duplicates: bool = False) -> "ShockInfo": """Combine shocks from two ShockInfo objects.""" if not isinstance(other, ShockInfo): raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo") overlapping = set(self.names) & set(other.names) - if overlapping: + if overlapping and not allow_duplicates: raise ValueError(f"Duplicate shock names found: {overlapping}") return ShockInfo(shocks=list(self.items) + list(other.items)) diff --git a/pymc_extras/statespace/models/structural/__init__.py b/pymc_extras/statespace/models/structural/__init__.py index 8ef35c969..f0bfb2f0a 100644 --- a/pymc_extras/statespace/models/structural/__init__.py +++ b/pymc_extras/statespace/models/structural/__init__.py @@ -5,9 +5,6 @@ from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError from pymc_extras.statespace.models.structural.components.regression import RegressionComponent -from pymc_extras.statespace.models.structural.components.regression_dataclass import ( - RegressionComponent as RegressionComponentDataClass, -) from pymc_extras.statespace.models.structural.components.seasonality import ( FrequencySeasonality, TimeSeasonality, @@ -20,6 +17,5 @@ "LevelTrendComponent", "MeasurementError", "RegressionComponent", - "RegressionComponentDataClass", "TimeSeasonality", ] diff --git a/pymc_extras/statespace/models/structural/components/regression.py b/pymc_extras/statespace/models/structural/components/regression.py index 1444b902b..f5ecf3d47 100644 --- a/pymc_extras/statespace/models/structural/components/regression.py +++ b/pymc_extras/statespace/models/structural/components/regression.py @@ -211,33 +211,23 @@ def _set_parameters(self) -> None: k_endog_effective = 1 if self.share_states else k_endog k_states = self.k_states // k_endog_effective - beta_param_name = f"beta_{self.name}" - beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) - beta_param_dims = ( - (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",) - ) - - beta_param_constraints = None beta_parameter = Parameter( - name=beta_param_name, - shape=beta_param_shape, - dims=beta_param_dims, - constraints=beta_param_constraints, + name=f"beta_{self.name}", + shape=(k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), + dims=( + (f"endog_{self.name}", f"state_{self.name}") + if k_endog_effective > 1 + else (f"state_{self.name}",) + ), + constraints=None, ) if self.innovations: - sigma_param_name = f"sigma_beta_{self.name}" - sigma_param_dims = (f"state_{self.name}",) - sigma_param_shape = (k_states,) - sigma_param_constraints = "Positive" - sigma_parameter = Parameter( - name=sigma_param_name, - shape=sigma_param_shape, - dims=sigma_param_dims, - constraints=sigma_param_constraints, + name=f"sigma_beta_{self.name}", + shape=(k_states,), + dims=(f"state_{self.name}",), + constraints="Positive", ) self.param_info = ParameterInfo(parameters=[beta_parameter, sigma_parameter]) @@ -251,11 +241,12 @@ def _set_data(self) -> None: k_endog_effective = 1 if self.share_states else k_endog k_states = self.k_states // k_endog_effective - data_name = f"data_{self.name}" - data_shape = (None, k_states) - data_dims = (TIME_DIM, f"state_{self.name}") - - data_prop = Data(name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True) + data_prop = Data( + name=f"data_{self.name}", + shape=(None, k_states), + dims=(TIME_DIM, f"state_{self.name}"), + is_exogenous=True, + ) self.data_info = DataInfo(data=[data_prop]) self.data_names = self.data_info.names @@ -274,9 +265,17 @@ def _set_states(self) -> None: if self.share_states: state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] self.state_info = StateInfo( - states=[State(name=name, observed=True, shared=True) for name in state_names] + states=[State(name=name, observed=False, shared=True) for name in state_names] + ) + self.state_info = self.state_info.merge( + StateInfo( + states=[ + State(name=name, observed=True, shared=False) + for name in self.observed_state_names + ] + ) ) - self.state_names = self.state_info.names + self.state_names = self.state_info.unobserved_state_names else: state_names = [ f"{name}[{obs_name}]" @@ -284,9 +283,17 @@ def _set_states(self) -> None: for name in self.base_names ] self.state_info = StateInfo( - states=[State(name=name, observed=True, shared=False) for name in state_names] + states=[State(name=name, observed=False, shared=False) for name in state_names] ) - self.state_names = self.state_info.names + self.state_info = self.state_info.merge( + StateInfo( + states=[ + State(name=name, observed=True, shared=False) + for name in self.observed_state_names + ] + ) + ) + self.state_names = self.state_info.unobserved_state_names def _set_coords(self) -> None: regression_state_coord = Coord( @@ -296,20 +303,11 @@ def _set_coords(self) -> None: dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] ) - self.coords = CoordInfo(coords=[regression_state_coord, endogenous_state_coord]) + self.coords_info = CoordInfo(coords=[regression_state_coord, endogenous_state_coord]) def populate_component_properties(self) -> None: - # Set parameter info self._set_parameters() - - # Set data info self._set_data() - - # Set shock info self._set_shocks() - - # Set states info self._set_states() - - # Set coordinates info self._set_coords() diff --git a/pymc_extras/statespace/models/structural/components/regression_dataclass.py b/pymc_extras/statespace/models/structural/components/regression_dataclass.py deleted file mode 100644 index 607b2469b..000000000 --- a/pymc_extras/statespace/models/structural/components/regression_dataclass.py +++ /dev/null @@ -1,539 +0,0 @@ -from dataclasses import dataclass, field - -import numpy as np - -from pytensor import tensor as pt - -from pymc_extras.statespace.models.structural.core import Component -from pymc_extras.statespace.utils.constants import TIME_DIM - - -@dataclass -class ParameterProperty: - name: str - shape: tuple[int, ...] - dims: tuple[str, ...] - constraints: str | None = None - - def __str__(self): - base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}" - if self.constraints: - return base + f"\nconstraints: {self.constraints}" - return base - - -@dataclass -class ParameterProperties: - parameters: list[ParameterProperty] - - def get_parameter(self, name: str) -> ParameterProperty | None: - return next((p for p in self.parameters if p.name == name), None) - - def __getitem__(self, name: str) -> ParameterProperty: - result = next((p for p in self.parameters if p.name == name), None) - if result is None: - raise KeyError(f"No parameter named '{name}'") - return result - - def __contains__(self, name: str) -> bool: - return any(p.name == name for p in self.parameters) - - def __str__(self): - base = f"parameters: {[parameter.name for parameter in self.parameters]}" - return base - - -@dataclass -class DataProperty: - name: str - shape: tuple[int, ...] - dims: tuple[str, ...] - is_exogenous: bool - - def __str__(self): - base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}\nis_exogenous: {self.is_exogenous}" - return base - - -@dataclass -class DataProperties: - data: list[DataProperty] - needs_exogenous_data: bool = field(default=False, init=False) - - def __post_init__(self): - for d in self.data: - if d.is_exogenous: - self.needs_exogenous_data = True - - def get_data(self, name: str) -> DataProperty | None: - return next((d for d in self.data if d.name == name), None) - - def __getitem__(self, name: str) -> DataProperty: - result = next((d for d in self.data if d.name == name), None) - if result is None: - raise KeyError(f"No data named '{name}'") - return result - - def __contains__(self, name: str) -> bool: - return any(d.name == name for d in self.data) - - def __str__(self): - base = f"data: {[d.name for d in self.data]}\nneeds exogenous data: {self.needs_exogenous_data}" - return base - - -@dataclass -class CoordProperty: - dimension: str - labels: list[str] - - def __str__(self): - base = f"dimension: {self.dimension}\nlabels: {self.labels}" - return base - - -@dataclass -class CoordProperties: - coords: list[CoordProperty] - - def get_coord(self, dimension: str) -> CoordProperty | None: - return next((c for c in self.coords if c.dimension == dimension), None) - - def __getitem__(self, dimension: str) -> CoordProperty: - result = next((c for c in self.coords if c.dimension == dimension), None) - if result is None: - raise KeyError(f"No coordinate named '{dimension}'") - return result - - def __contains__(self, dimension: str) -> bool: - return any(c.dimension == dimension for c in self.coords) - - def __str__(self): - base = "coordinates:" - for coord in self.coords: - coord_str = str(coord) - indented = "\n".join(" " + line for line in coord_str.splitlines()) - base += "\n" + indented + "\n" - return base - - -@dataclass -class StateProperty: - name: str - observed: bool - shared: bool - - def __str__(self): - base = f"name: {self.name}\nobserved: {self.observed}\nshared: {self.shared}" - return base - - -@dataclass -class StateProperties: - states: list[StateProperty] - - def get_state(self, name: str) -> StateProperty | None: - return next((s for s in self.states if s.name == name), None) - - def __getitem__(self, name: str) -> StateProperty: - result = next((s for s in self.states if s.name == name), None) - if result is None: - raise KeyError(f"No state named '{name}'") - return result - - def __contains__(self, name: str) -> bool: - return any(s.name == name for s in self.states) - - def __str__(self): - base = f"states: {[state.name for state in self.states]}\nobserved: {[state.observed for state in self.states]}" - return base - - -@dataclass -class ShockProperty: - name: str - - def __str__(self): - base = f"name: {self.name}" - return base - - -@dataclass -class ShockProperties: - shocks: list[ShockProperty] - - def get_state(self, name: str) -> ShockProperty | None: - return next((shock for shock in self.shocks if shock.name == name), None) - - def __getitem__(self, name: str) -> ShockProperty: - result = next((shock for shock in self.shocks if shock.name == name), None) - if result is None: - raise KeyError(f"No shock named '{name}'") - return result - - def __contains__(self, name: str) -> bool: - return any(shock.name == name for shock in self.shocks) - - def __str__(self): - base = f"shocks: {[shock.name for shock in self.shocks]}" - return base - - -class RegressionComponent(Component): - r""" - Regression component for exogenous variables in a structural time series model - - Parameters - ---------- - k_exog : int | None, default None - Number of exogenous variables to include in the regression. Must be specified if - state_names is not provided. - - name : str | None, default "regression" - A name for this regression component. Used to label dimensions and coordinates. - - state_names : list[str] | None, default None - List of strings for regression coefficient labels. If provided, must be of length - k_exog. If None and k_exog is provided, coefficients will be named - "{name}_1, {name}_2, ...". - - observed_state_names : list[str] | None, default None - List of strings for observed state labels. If None, defaults to ["data"]. - - innovations : bool, default False - Whether to include stochastic innovations in the regression coefficients, - allowing them to vary over time. If True, coefficients follow a random walk. - - share_states: bool, default False - Whether latent states are shared across the observed states. If True, there will be only one set of latent - states, which are observed by all observed states. If False, each observed state has its own set of - latent states. - - Notes - ----- - This component implements regression with exogenous variables in a structural time series - model. The regression component can be expressed as: - - .. math:: - y_t = \beta_t^T x_t + \epsilon_t - - Where :math:`y_t` is the dependent variable, :math:`x_t` is the vector of exogenous - variables, :math:`\beta_t` is the vector of regression coefficients, and :math:`\epsilon_t` - is the error term. - - When ``innovations=False`` (default), the coefficients are constant over time: - :math:`\beta_t = \beta_0` for all t. - - When ``innovations=True``, the coefficients follow a random walk: - :math:`\beta_{t+1} = \beta_t + \eta_t`, where :math:`\eta_t \sim N(0, \Sigma_\beta)`. - - The component supports both univariate and multivariate regression. In the multivariate - case, separate coefficients are estimated for each endogenous variable (i.e time series). - - Examples - -------- - Simple regression with constant coefficients: - - .. code:: python - - from pymc_extras.statespace import structural as st - import pymc as pm - import pytensor.tensor as pt - - trend = st.LevelTrendComponent(order=1, innovations_order=1) - regression = st.RegressionComponent(k_exog=2, state_names=['intercept', 'slope']) - ss_mod = (trend + regression).build() - - with pm.Model(coords=ss_mod.coords) as model: - # Prior for regression coefficients - betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) - - # Prior for trend innovations - sigma_trend = pm.Exponential('sigma_trend', 1) - - ss_mod.build_statespace_graph(data) - idata = pm.sample() - - Multivariate regression with time-varying coefficients: - - There are 2 exogenous variables (price and income effects) - - There are 2 endogenous variables (sales and revenue) - - The regression coefficients are allowed to vary over time (`innovations=True`) - - .. code:: python - - regression = st.RegressionComponent( - k_exog=2, - state_names=['price_effect', 'income_effect'], - observed_state_names=['sales', 'revenue'], - innovations=True - ) - - with pm.Model(coords=ss_mod.coords) as model: - betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) - - # Innovation variance for time-varying coefficients - sigma_beta = pm.Exponential('sigma_beta', 1, dims=ss_mod.param_dims['sigma_beta_regression']) - - ss_mod.build_statespace_graph(data) - idata = pm.sample() - """ - - def __init__( - self, - k_exog: int | None = None, - name: str | None = "regression", - state_names: list[str] | None = None, - observed_state_names: list[str] | None = None, - innovations=False, - share_states: bool = False, - ): - self.share_states = share_states - - if observed_state_names is None: - observed_state_names = ["data"] - - self.innovations = innovations - k_exog = self._handle_input_data(k_exog, state_names, name) - - k_states = k_exog - k_endog = len(observed_state_names) - k_posdef = k_exog - - super().__init__( - name=name, - k_endog=k_endog, - k_states=k_states * k_endog if not share_states else k_states, - k_posdef=k_posdef * k_endog if not share_states else k_posdef, - state_names=self.state_names, - share_states=share_states, - observed_state_names=observed_state_names, - measurement_error=False, - combine_hidden_states=False, - exog_names=[f"data_{name}"], - obs_state_idxs=np.ones(k_states), - ) - - @staticmethod - def _get_state_names(k_exog: int | None, state_names: list[str] | None, name: str): - if k_exog is None and state_names is None: - raise ValueError("Must specify at least one of k_exog or state_names") - if state_names is not None and k_exog is not None: - if len(state_names) != k_exog: - raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}") - elif k_exog is None: - k_exog = len(state_names) - else: - state_names = [f"{name}_{i + 1}" for i in range(k_exog)] - - return k_exog, state_names - - def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -> int: - k_exog, state_names = self._get_state_names(k_exog, state_names, name) - self.state_names = state_names - - return k_exog - - def make_symbolic_graph(self) -> None: - k_endog = self.k_endog - k_endog_effective = 1 if self.share_states else k_endog - - k_states = self.k_states // k_endog_effective - - betas = self.make_and_register_variable( - f"beta_{self.name}", shape=(k_endog, k_states) if k_endog_effective > 1 else (k_states,) - ) - regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states)) - - self.ssm["initial_state", :] = betas.ravel() - self.ssm["transition", :, :] = pt.eye(self.k_states) - self.ssm["selection", :, :] = pt.eye(self.k_states) - - if self.share_states: - self.ssm["design"] = pt.specify_shape( - pt.join(1, *[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]), - (None, k_endog, self.k_states), - ) - else: - Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]) - self.ssm["design"] = pt.specify_shape( - Z, (None, k_endog, regression_data.type.shape[1] * k_endog) - ) - - if self.innovations: - sigma_beta = self.make_and_register_variable( - f"sigma_beta_{self.name}", - (k_states,) if k_endog_effective == 1 else (k_endog, k_states), - ) - row_idx, col_idx = np.diag_indices(self.k_states) - self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2 - - def _set_parameters(self) -> None: - k_endog = self.k_endog - k_endog_effective = 1 if self.share_states else k_endog - k_states = self.k_states // k_endog_effective - - beta_param_name = f"beta_{self.name}" - beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) - beta_param_dims = ( - (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",) - ) - - beta_param_constraints = None - - if self.innovations: - sigma_param_name = f"sigma_beta_{self.name}" - sigma_param_dims = (f"state_{self.name}",) - sigma_param_shape = (k_states,) - sigma_param_constraints = "Positive" - - beta_parameter = ParameterProperty( - name=beta_param_name, - shape=beta_param_shape, - dims=beta_param_dims, - constraints=beta_param_constraints, - ) - - sigma_parameter = ParameterProperty( - name=sigma_param_name, - shape=sigma_param_shape, - dims=sigma_param_dims, - constraints=sigma_param_constraints, - ) - - self.param_info = ParameterProperties(parameters=[beta_parameter, sigma_parameter]) - - def _set_data(self) -> None: - k_endog = self.k_endog - k_endog_effective = 1 if self.share_states else k_endog - k_states = self.k_states // k_endog_effective - - data_name = f"data_{self.name}" - data_shape = (None, k_states) - data_dims = (TIME_DIM, f"state_{self.name}") - - data_prop = DataProperty( - name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True - ) - self.data_info = DataProperties(data=[data_prop]) - - def _set_shocks(self) -> None: - if self.share_states: - shock_names = [f"{state_name}_shared" for state_name in self.state_names] - else: - shock_names = self.state_names - - self.shock_info = ShockProperties(shocks=[ShockProperty(name=name) for name in shock_names]) - - def _set_states(self) -> None: - self.base_names = self.state_names - - if self.share_states: - state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] - self.state_names = StateProperties( - states=[ - StateProperty(name=name, observed=True, shared=True) for name in state_names - ] - ) - else: - state_names = [ - f"{name}[{obs_name}]" - for obs_name in self.observed_state_names - for name in self.base_names - ] - self.state_names = StateProperties( - states=[ - StateProperty(name=name, observed=True, shared=False) for name in state_names - ] - ) - - def _set_coords(self) -> None: - regression_state_prop = CoordProperty( - dimension=f"state_{self.name}", labels=[state for state in self.base_names] - ) - endogenous_state_prop = CoordProperty( - dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] - ) - - self.coords = CoordProperties(coords=[regression_state_prop, endogenous_state_prop]) - - def populate_component_properties(self) -> None: - # k_endog_eff, k_states = self._effective_shape_info() - - # 1. Set parameter info - self._set_parameters() - - # 2. Set data info - self._set_data() - - # 3. Set shock info - self._set_shocks() - - # 4. Set states info - self._set_states() - - # 5. Set coordinates info - self._set_coords() - - # def populate_component_properties(self) -> None: - # k_endog = self.k_endog - # k_endog_effective = 1 if self.share_states else k_endog - - # k_states = self.k_states // k_endog_effective - - # if self.share_states: - # self.shock_names = [f"{state_name}_shared" for state_name in self.state_names] - # else: - # self.shock_names = self.state_names - - # self.param_names = [f"beta_{self.name}"] - # self.data_names = [f"data_{self.name}"] - # self.param_dims = { - # f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}") - # if k_endog_effective > 1 - # else (f"state_{self.name}",) - # } - - # base_names = self.state_names - - # if self.share_states: - # self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] - # else: - # self.state_names = [ - # f"{name}[{obs_name}]" - # for obs_name in self.observed_state_names - # for name in base_names - # ] - - # self.param_info = { - # f"beta_{self.name}": { - # "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), - # "constraints": None, - # "dims": (f"endog_{self.name}", f"state_{self.name}") - # if k_endog_effective > 1 - # else (f"state_{self.name}",), - # }, - # } - - # self.data_info = { - # f"data_{self.name}": { - # "shape": (None, k_states), - # "dims": (TIME_DIM, f"state_{self.name}"), - # }, - # } - # self.coords = { - # f"state_{self.name}": base_names, - # f"endog_{self.name}": self.observed_state_names, - # } - - # if self.innovations: - # self.param_names += [f"sigma_beta_{self.name}"] - # self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) - # self.param_info[f"sigma_beta_{self.name}"] = { - # "shape": (k_states,), - # "constraints": "Positive", - # "dims": (f"state_{self.name}",) - # if k_endog_effective == 1 - # else (f"endog_{self.name}", f"state_{self.name}"), - # } diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 7b159e7f7..8d777bc88 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -14,14 +14,20 @@ from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation from pymc_extras.statespace.core.properties import ( + CoordInfo, + Data, + DataInfo, Parameter, ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, ) from pymc_extras.statespace.models.utilities import ( add_tensors_by_dim_labels, conform_time_varying_and_time_invariant_matrices, join_tensors_by_dim_labels, - make_default_coords, ) from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, @@ -141,23 +147,15 @@ class StructuralTimeSeries(PyMCStateSpace): methods (2nd ed.). Oxford University Press. """ - # TODO need to discuss cutting some of these args down. All the _names args are already inside of _info def __init__( self, ssm: PytensorRepresentation, name: str, - state_names: list[str], - observed_state_names: list[str], - data_names: list[str], - shock_names: list[str], - param_names: list[str], - exog_names: list[str], - param_dims: dict[str, tuple[int]], - coords: dict[str, Sequence], - param_info: dict[str, dict[str, Any]], - data_info: dict[str, dict[str, Any]], - shock_info: dict[str, dict[str, Any]], - state_info: dict[str, dict[str, Any]], + coords_info: CoordInfo, + param_info: ParameterInfo, + data_info: DataInfo, + shock_info: ShockInfo, + state_info: StateInfo, component_info: dict[str, dict[str, Any]], measurement_error: bool, name_to_variable: dict[str, Variable], @@ -169,27 +167,32 @@ def __init__( name = "StructuralTimeSeries" if name is None else name self._name = name - self._observed_state_names = observed_state_names + self._observed_state_names = state_info.observed_state_names k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog param_names, param_dims, param_info = self._add_inital_state_cov_to_properties( param_info, k_states ) - self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog) - self._data_names = self._strip_data_names_if_unambiguous(data_names, k_endog) - self._shock_names = self._strip_data_names_if_unambiguous(shock_names, k_endog) + self._state_names = self._strip_data_names_if_unambiguous( + state_info.unobserved_state_names, k_endog + ) + self._data_names = self._strip_data_names_if_unambiguous( + [d.name for d in data_info if not d.is_exogenous], k_endog + ) + self._shock_names = self._strip_data_names_if_unambiguous(shock_info.names, k_endog) self._param_names = self._strip_data_names_if_unambiguous(param_names, k_endog) self._param_dims = param_dims - default_coords = make_default_coords(self) - coords = coords.merge(default_coords) + default_coords = coords_info.default_coords_from_model(self) + coords_info = coords_info.merge(default_coords) - self._coord_info = coords - self._param_info = param_info # .copy() #TODO add __copy__ to base class - self._data_info = data_info # .copy() - self._shock_info = shock_info - self._state_info = state_info + # TODO: discuss if copying is still needed since these are now immutable + self._coord_info = coords_info.copy() + self._param_info = param_info.copy() + self._data_info = data_info.copy() + self._shock_info = shock_info.copy() + self._state_info = state_info.copy() self.measurement_error = measurement_error super().__init__( @@ -218,8 +221,8 @@ def __init__( self._name_to_variable = name_to_variable.copy() self._name_to_data = name_to_data.copy() - self._exog_names = exog_names.copy() - self._needs_exog_data = len(exog_names) > 0 + self._exog_names = data_info.exogenous_names + self._needs_exog_data = data_info.needs_exogenous_data P0 = self.make_and_register_variable("P0", shape=(self.k_states, self.k_states)) self.ssm["initial_state_cov"] = P0 @@ -235,26 +238,21 @@ def _strip_data_names_if_unambiguous(self, names: list[str], k_endog: int): """ if k_endog == 1: [data_name] = self.observed_states - return [ + return tuple( name.replace(f"[{data_name}]", "") if isinstance(name, str) else name for name in names - ] + ) else: return names @staticmethod def _add_inital_state_cov_to_properties(param_info, k_states): - initial_state_cov_name = "P0" - initial_state_cov_shape = (k_states, k_states) - initial_state_cov_dims = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) - initial_state_cov_constraints = "Positive semi-definite" - initial_state_cov_param = Parameter( - name=initial_state_cov_name, - shape=initial_state_cov_shape, - dims=initial_state_cov_dims, - constraints=initial_state_cov_constraints, + name="P0", + shape=(k_states, k_states), + dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM), + constraints="Positive semi-definite", ) if is_dataclass(param_info): @@ -272,6 +270,10 @@ def param_names(self): def data_names(self) -> list[str]: return self._data_names + @property + def exog_names(self) -> list[str]: + return self._exog_names + @property def state_names(self): return self._state_names @@ -288,25 +290,9 @@ def shock_names(self): def param_dims(self): return self._param_dims - @property # TODO discuss naming convention _info and need to clean up type hints - def coords(self) -> dict[str, Sequence]: - return self._coord_info - - @property - def param_info(self) -> dict[str, dict[str, Any]]: - return self._param_info - - @property - def data_info(self) -> dict[str, dict[str, Any]]: - return self._data_info - - @property - def state_info(self) -> dict[str, dict[str, Any]]: - return self._state_info - @property - def shock_info(self) -> dict[str, dict[str, Any]]: - return self._shock_info + def coords(self) -> dict[str, Sequence]: + return self._coord_info.to_dict() def make_symbolic_graph(self) -> None: """ @@ -550,25 +536,43 @@ def __init__( self.k_posdef = k_posdef self.measurement_error = measurement_error - self.state_names = list(state_names) if state_names is not None else [] - self.observed_state_names = ( - list(observed_state_names) if observed_state_names is not None else [] + self.param_info = ParameterInfo( + parameters=[ + Parameter(name=n, shape=(1,), dims=(f"{n}_placeholder")) + for n in (param_names or []) + ] + ) + self.data_info = DataInfo( + data=[ + Data(name=n, shape=(None, 1), dims=(f"{n}_placeholder"), is_exogenous=False) + for n in (data_names or []) + ] + + [ + Data(name=n, shape=(None, 1), dims=(f"{n}_placeholder"), is_exogenous=True) + for n in (exog_names or []) + ] + ) + self.shock_info = ShockInfo(shocks=[Shock(name=n) for n in (shock_names or [])]) + self.state_info = StateInfo( + states=[State(name=n, observed=False, shared=share_states) for n in (state_names or [])] + + [ + State(name=n, observed=True, shared=share_states) + for n in (observed_state_names or []) + ] ) - self.data_names = list(data_names) if data_names is not None else [] - self.shock_names = list(shock_names) if shock_names is not None else [] - self.param_names = list(param_names) if param_names is not None else [] - self.exog_names = list(exog_names) if exog_names is not None else [] + self.coord_info = CoordInfo(coords=[]) - self.needs_exog_data = len(self.exog_names) > 0 - self.coords = {} - self.param_dims = {} + self.state_names = self.state_info.unobserved_state_names + self.observed_state_names = self.state_info.observed_state_names + self.param_names = self.param_info.names + self.data_names = [d.name for d in self.data_info if not d.is_exogenous] + self.exog_names = self.data_info.exogenous_names + self.shock_names = self.shock_info.names - self.param_info = {} - self.data_info = {} - self.shock_info = {} - self.state_info = {} + self.coords = self.coord_info.to_dict() + self.param_dims = [p.dims for p in self.param_info] - self.param_counts = {} + self.needs_exog_data = self.data_info.needs_exogenous_data if representation is None: self.ssm = PytensorRepresentation(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef) @@ -701,10 +705,8 @@ def _get_combined_shapes(self, other): k_posdef = self.k_posdef + other.k_posdef # To count endog states, we have to count unique names between the two components. - combined_states = self._combine_property( - other, "observed_state_names", allow_duplicates=False - ) - k_endog = len(combined_states) + combined_states = self._combine_property(other, "state_info", allow_duplicates=False) + k_endog = len(combined_states.observed_state_names) return k_states, k_posdef, k_endog @@ -806,34 +808,19 @@ def _combine_property(self, other, name, allow_duplicates=True): self_prop = getattr(self, name) other_prop = getattr(other, name) - # TODO discuss limiting the types we get here to only a dataclass type. By making the dataclasses immutable we now have to handle for tuples too. - if not isinstance(self_prop, type(other_prop)): raise TypeError( f"Property {name} of {self} and {other} are not the same and cannot be combined. Found " f"{type(self_prop)} for {self} and {type(other_prop)} for {other}'" ) - if not is_dataclass(self_prop) and not isinstance(self_prop, list | tuple | dict): + if not is_dataclass(self_prop): raise TypeError( f"All component properties are expected to be dataclasses, but found {type(self_prop)}" f"for property {name} of {self} and {type(other_prop)} for {other}'" ) - if isinstance(self_prop, list) and allow_duplicates: - return self_prop + other_prop - elif isinstance(self_prop, list) and not allow_duplicates: - return self_prop + [x for x in other_prop if x not in self_prop] - elif isinstance(self_prop, dict): - new_prop = self_prop.copy() - new_prop.update(other_prop) - return new_prop - # TODO need to handle allow_duplicates but want to wait for above discussion first to see if we can cut down to just dataclass types - elif isinstance(self_prop, tuple): - new_prop = self_prop + other_prop - return new_prop - elif is_dataclass(self_prop): - return self_prop.merge(other_prop) + return self_prop.merge(other_prop, allow_duplicates) def _combine_component_info(self, other): combined_info = {} @@ -857,24 +844,23 @@ def _make_combined_name(self): return name def __add__(self, other): - state_names = self._combine_property(other, "state_names") - data_names = self._combine_property(other, "data_names") - observed_state_names = self._combine_property( - other, "observed_state_names", allow_duplicates=False - ) - - param_names = self._combine_property(other, "param_names") - shock_names = self._combine_property(other, "shock_names") param_info = self._combine_property(other, "param_info") data_info = self._combine_property(other, "data_info") shock_info = self._combine_property(other, "shock_info") state_info = self._combine_property(other, "state_info") - param_dims = self._combine_property(other, "param_dims") - coords = self._combine_property(other, "coords") - exog_names = self._combine_property(other, "exog_names") + coords_info = self._combine_property(other, "coords_info") + + state_names = state_info.unobserved_state_names + observed_state_names = state_info.observed_state_names + data_names = [d.name for d in data_info if not d.is_exogenous] + exog_names = data_info.exogenous_names + param_names = param_info.names + shock_names = shock_info.names + param_dims = [p.dims for p in param_info] - _name_to_variable = self._combine_property(other, "_name_to_variable") - _name_to_data = self._combine_property(other, "_name_to_data") + # TODO: Figure out how to handle these items in dataclasses + # _name_to_variable = self._combine_property(other, "_name_to_variable") + # _name_to_data = self._combine_property(other, "_name_to_data") measurement_error = any([self.measurement_error, other.measurement_error]) @@ -901,16 +887,15 @@ def __add__(self, other): ("data_names", data_names), ("param_names", param_names), ("shock_names", shock_names), - ("param_dims", param_dims), - ("coords", coords), - ("param_dims", param_dims), + ("coords_info", coords_info), ("param_info", param_info), ("data_info", data_info), ("shock_info", shock_info), ("state_info", state_info), ("exog_names", exog_names), - ("_name_to_variable", _name_to_variable), - ("_name_to_data", _name_to_data), + ("param_dims", param_dims), + # ("_name_to_variable", _name_to_variable), # TODO: Need to figure out how to handle these objects + # ("_name_to_data", _name_to_data), ] for prop, value in names_and_props: @@ -953,20 +938,13 @@ def build( return StructuralTimeSeries( self.ssm, name=name, - state_names=self.state_names, - observed_state_names=self.observed_state_names, - data_names=self.data_names, - shock_names=self.shock_names, - param_names=self.param_names, - param_dims=self.param_dims, - coords=self.coords, + coords_info=self.coords_info, param_info=self.param_info, data_info=self.data_info, shock_info=self.shock_info, state_info=self.state_info, component_info=self._component_info, measurement_error=self.measurement_error, - exog_names=self.exog_names, name_to_variable=self._name_to_variable, name_to_data=self._name_to_data, filter_type=filter_type, diff --git a/tests/statespace/models/structural/components/test_regression.py b/tests/statespace/models/structural/components/test_regression.py index ffde3348f..7af48e2f4 100644 --- a/tests/statespace/models/structural/components/test_regression.py +++ b/tests/statespace/models/structural/components/test_regression.py @@ -66,7 +66,7 @@ def test_exogenous_component(self, rng, regression_data, innovations): mod = mod.build(verbose=False) _assert_basic_coords_correct(mod) - assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] + assert mod.coords["state_exog"] == ["feature_1", "feature_2"] if innovations: # Check that sigma_beta parameter is included @@ -125,7 +125,7 @@ def test_regression_with_multiple_observed_states(self, rng, regression_data, in assert_allclose(x[0, 2:], params["beta_exog"][1], atol=ATOL, rtol=RTOL) mod = mod.build(verbose=False) - assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] + assert mod.coords["state_exog"] == ["feature_1", "feature_2"] Z = mod.ssm["design"].eval({"data_exog": regression_data}) vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)") @@ -164,8 +164,8 @@ def test_add_regression_components_with_multiple_observed_states( ) mod = (reg1 + reg2).build(verbose=False) - assert mod.coords["state_exog1"].labels == ["a", "b"] - assert mod.coords["state_exog2"].labels == ["c"] + assert mod.coords["state_exog1"] == ["a", "b"] + assert mod.coords["state_exog2"] == ["c"] Z = mod.ssm["design"].eval( { @@ -211,7 +211,7 @@ def test_filter_scans_time_varying_design_matrix(self, rng, time_series_data, in reg = st.RegressionComponent(state_names=["a", "b"], name="exog", innovations=innovations) mod = reg.build(verbose=False) - with pm.Model(coords=mod.coords.to_dict()) as m: + with pm.Model(coords=mod.coords) as m: data_exog = pm.Data("data_exog", data.values) x0 = pm.Normal("x0", dims=["state"]) @@ -249,8 +249,8 @@ def test_regression_multiple_shared_construction(): assert mod.k_states == 1 assert mod.k_posdef == 1 - assert mod.coords["state_regression"].labels == ["A"] - assert mod.coords["endog_regression"].labels == ["data_1", "data_2"] + assert mod.coords["state_regression"] == ["A"] + assert mod.coords["endog_regression"] == ["data_1", "data_2"] assert mod.state_names == ("A[regression_shared]",) @@ -291,6 +291,7 @@ def test_regression_multiple_shared_observed(rng): np.testing.assert_allclose(y[:, 0], y[:, 2]) +@pytest.mark.filterwarnings("ignore::UserWarning") def test_regression_mixed_shared_and_not_shared(): mod_1 = st.RegressionComponent( name="individual", diff --git a/tests/statespace/models/structural/conftest.py b/tests/statespace/models/structural/conftest.py index 15dac710d..a395b528c 100644 --- a/tests/statespace/models/structural/conftest.py +++ b/tests/statespace/models/structural/conftest.py @@ -19,11 +19,11 @@ def rng(): def _assert_basic_coords_correct(mod): - assert mod.coords[ALL_STATE_DIM].labels == mod.state_names - assert mod.coords[ALL_STATE_AUX_DIM].labels == mod.state_names - assert mod.coords[SHOCK_DIM].labels == mod.shock_names - assert mod.coords[SHOCK_AUX_DIM].labels == mod.shock_names - expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ["data"] + assert mod.coords[ALL_STATE_DIM] == mod.state_names + assert mod.coords[ALL_STATE_AUX_DIM] == mod.state_names + assert mod.coords[SHOCK_DIM] == mod.shock_names + assert mod.coords[SHOCK_AUX_DIM] == mod.shock_names + expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ("data",) - assert mod.coords[OBS_STATE_DIM].labels == expected_obs - assert mod.coords[OBS_STATE_AUX_DIM].labels == expected_obs + assert mod.coords[OBS_STATE_DIM] == expected_obs + assert mod.coords[OBS_STATE_AUX_DIM] == expected_obs From 228acff251a27a748b02ff87b3a3fc6159564744 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sun, 16 Nov 2025 09:44:02 -0700 Subject: [PATCH 7/8] 1. added add and merge methods to base class 2. created tests for add and merge methods 3. added utility to convert from snake to pascal and integrated it in error messaging --- pymc_extras/statespace/core/properties.py | 92 +++++-------------- .../statespace/models/structural/core.py | 5 +- pymc_extras/statespace/utils/message_tools.py | 5 + tests/statespace/core/test_properties.py | 25 ++++- 4 files changed, 56 insertions(+), 71 deletions(-) create mode 100644 pymc_extras/statespace/utils/message_tools.py diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py index 35c1394d0..d01e90928 100644 --- a/pymc_extras/statespace/core/properties.py +++ b/pymc_extras/statespace/core/properties.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import warnings from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass, fields -from typing import Generic, Self, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from pymc_extras.statespace.core import PyMCStateSpace from pymc_extras.statespace.utils.constants import ( @@ -15,6 +17,9 @@ SHOCK_DIM, ) +if TYPE_CHECKING: + from pymc_extras.statespace.models.structural.core import Component + @dataclass(frozen=True) class Property: @@ -62,7 +67,7 @@ def __getitem__(self, key: str) -> T: def __contains__(self, key: object) -> bool: return key in self._index - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[T]: return iter(self.items) def __len__(self) -> int: @@ -71,11 +76,24 @@ def __len__(self) -> int: def __str__(self) -> str: return f"{self.key_field}s: {list(self._index.keys())}" + def add(self, new_item: T): + return type(self)([*self.items, new_item]) + + def merge(self, other: Self, allow_duplicates: bool = False) -> Self: + if not isinstance(other, type(self)): + raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}") + + overlapping = set(self.names) & set(other.names) + if overlapping and not allow_duplicates: + raise ValueError(f"Duplicate names found: {overlapping}") + + return type(self)(list(self.items) + list(other.items)) + @property def names(self) -> tuple[str, ...]: return tuple(self._index.keys()) - def copy(self) -> "Info[T]": + def copy(self) -> Info[T]: return deepcopy(self) @@ -92,21 +110,6 @@ class ParameterInfo(Info[Parameter]): def __init__(self, parameters: list[Parameter]): super().__init__(items=tuple(parameters), key_field="name") - def add(self, parameter: Parameter) -> "ParameterInfo": - # return a new ParameterInfo with parameter appended - return ParameterInfo(parameters=[*list(self.items), parameter]) - - def merge(self, other: "ParameterInfo", allow_duplicates: bool = False) -> "ParameterInfo": - """Combine parameters from two ParameterInfo objects.""" - if not isinstance(other, ParameterInfo): - raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo") - - overlapping = set(self.names) & set(other.names) - if overlapping and not allow_duplicates: - raise ValueError(f"Duplicate parameter names found: {overlapping}") - - return ParameterInfo(parameters=list(self.items) + list(other.items)) - @dataclass(frozen=True) class Data(Property): @@ -132,21 +135,6 @@ def exogenous_names(self) -> tuple[str, ...]: def __str__(self) -> str: return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}" - def add(self, data: Data) -> "DataInfo": - # return a new DataInfo with data appended - return DataInfo(data=[*list(self.items), data]) - - def merge(self, other: "DataInfo", allow_duplicates: bool = False) -> "DataInfo": - """Combine data from two DataInfo objects.""" - if not isinstance(other, DataInfo): - raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo") - - overlapping = set(self.names) & set(other.names) - if overlapping and not allow_duplicates: - raise ValueError(f"Duplicate data names found: {overlapping}") - - return DataInfo(data=list(self.items) + list(other.items)) - @dataclass(frozen=True) class Coord(Property): @@ -169,7 +157,7 @@ def __str__(self) -> str: @classmethod def default_coords_from_model( - cls, model: PyMCStateSpace + cls, model: PyMCStateSpace | Component ) -> ( Self ): # TODO: Need to figure out how to include Component type was causing circular import issues @@ -192,21 +180,6 @@ def default_coords_from_model( def to_dict(self): return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0} - def add(self, coord: Coord) -> "CoordInfo": - # return a new CoordInfo with data appended - return CoordInfo(coords=[*list(self.items), coord]) - - def merge(self, other: "CoordInfo", allow_duplicates: bool = False) -> "CoordInfo": - """Combine data from two CoordInfo objects.""" - if not isinstance(other, CoordInfo): - raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo") - - overlapping = set(self.names) & set(other.names) - if overlapping and not allow_duplicates: - raise ValueError(f"Duplicate coord names found: {overlapping}") - - return CoordInfo(coords=list(self.items) + list(other.items)) - @dataclass(frozen=True) class State(Property): @@ -237,11 +210,7 @@ def observed_state_names(self) -> tuple[State, ...]: def unobserved_state_names(self) -> tuple[State, ...]: return tuple(s.name for s in self.items if not s.observed) - def add(self, state: State) -> "StateInfo": - # return a new StateInfo with state appended - return StateInfo(states=[*list(self.items), state]) - - def merge(self, other: "StateInfo", allow_duplicates: bool = False) -> "StateInfo": + def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo: """Combine states from two StateInfo objects.""" if not isinstance(other, StateInfo): raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo") @@ -270,18 +239,3 @@ class Shock(Property): class ShockInfo(Info[Shock]): def __init__(self, shocks: list[Shock]): super().__init__(items=tuple(shocks), key_field="name") - - def add(self, shock: Shock) -> "ShockInfo": - # return a new ShockInfo with shock appended - return ShockInfo(shocks=[*list(self.items), shock]) - - def merge(self, other: "ShockInfo", allow_duplicates: bool = False) -> "ShockInfo": - """Combine shocks from two ShockInfo objects.""" - if not isinstance(other, ShockInfo): - raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo") - - overlapping = set(self.names) & set(other.names) - if overlapping and not allow_duplicates: - raise ValueError(f"Duplicate shock names found: {overlapping}") - - return ShockInfo(shocks=list(self.items) + list(other.items)) diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 8d777bc88..160a22763 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -34,6 +34,7 @@ ALL_STATE_DIM, LONG_MATRIX_NAMES, ) +from pymc_extras.statespace.utils.message_tools import snake_to_pascal _log = logging.getLogger(__name__) floatX = config.floatX @@ -815,8 +816,10 @@ def _combine_property(self, other, name, allow_duplicates=True): ) if not is_dataclass(self_prop): + # TODO: This works right now because we are only passing _info info names into _combine_property + # If we don't follow that schema moving forward this will break. raise TypeError( - f"All component properties are expected to be dataclasses, but found {type(self_prop)}" + f"Component properties are expected to be {snake_to_pascal(name)}, but found {type(self_prop)}" f"for property {name} of {self} and {type(other_prop)} for {other}'" ) diff --git a/pymc_extras/statespace/utils/message_tools.py b/pymc_extras/statespace/utils/message_tools.py new file mode 100644 index 000000000..7f37affe0 --- /dev/null +++ b/pymc_extras/statespace/utils/message_tools.py @@ -0,0 +1,5 @@ +import re + + +def snake_to_pascal(s: str) -> str: + return re.sub(r"(?:^|_)([a-z])", lambda m: m.group(1).upper(), s) diff --git a/tests/statespace/core/test_properties.py b/tests/statespace/core/test_properties.py index 7f7cb8ae3..396731b93 100644 --- a/tests/statespace/core/test_properties.py +++ b/tests/statespace/core/test_properties.py @@ -68,7 +68,7 @@ def test_data_info_needs_exogenous_and_str(): def test_coord_info_make_defaults_from_component_and_types(): class DummyComponent: state_names = ["x1", "x2"] - observed_state_names = ["x2"] + observed_states = ["x2"] shock_names = ["eps1"] ci = CoordInfo.default_coords_from_model(DummyComponent()) @@ -117,3 +117,26 @@ def test_info_is_iterable_and_unpackable(): a, b = info.items assert a.name == "p1" and b.name == "p2" + + +def test_info_add_method(): + a_param = Parameter(name="a", shape=(1,), dims=("dim",)) + param_info = ParameterInfo(parameters=[a_param]) + + b_param = Parameter(name="b", shape=(1,), dims=("dim",)) + + new_param_info = param_info.add(new_item=b_param) + + assert new_param_info.names == ("a", "b") + + +def test_info_merge_method(): + a_param = Parameter(name="a", shape=(1,), dims=("dim",)) + a_param_info = ParameterInfo(parameters=[a_param]) + + b_param = Parameter(name="b", shape=(1,), dims=("dim",)) + b_param_info = ParameterInfo(parameters=[b_param]) + + new_param_info = a_param_info.merge(b_param_info) + + assert new_param_info.names == ("a", "b") From 1ae433f5261796a3ccd29a0bd25646ecbe994464 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 22 Nov 2025 07:50:36 -0700 Subject: [PATCH 8/8] removed data & coords setters in _set medthod in Component class and placed default shoch and state setters --- .../statespace/models/structural/core.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 160a22763..a7a1af043 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -683,17 +683,20 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: def _set_parameters(self) -> None: raise NotImplementedError - def _set_data(self) -> None: - raise NotImplementedError - def _set_shocks(self) -> None: - raise NotImplementedError + return ShockInfo(shocks=[Shock(name=f"shock_{n}") for n in range(self.k_posdef or 0)]) def _set_states(self) -> None: - raise NotImplementedError - - def _set_coords(self) -> None: - raise NotImplementedError + return StateInfo( + states=[ + State(name=n, observed=False, shared=self.share_states) + for n in range(self.k_states or 0) + ] + + [ + State(name=n, observed=True, shared=self.share_states) + for n in range(self.k_endog or 0) + ] + ) def make_symbolic_graph(self) -> None: raise NotImplementedError