From 7f4d02f0d8c7685896389eb7a6fe4309112d8603 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 25 Aug 2025 13:39:57 +0800 Subject: [PATCH 1/5] First pass on exogenous variables in VARMA --- pymc_extras/statespace/models/VARMAX.py | 165 ++++++++++++++++-------- tests/statespace/models/test_VARMAX.py | 146 +++++++++++++++++++++ 2 files changed, 257 insertions(+), 54 deletions(-) diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 79ad25b97..37d592bd3 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -28,60 +28,6 @@ class BayesianVARMAX(PyMCStateSpace): r""" Vector AutoRegressive Moving Average with eXogenous Regressors - Parameters - ---------- - order: tuple of (int, int) - Number of autoregressive (AR) and moving average (MA) terms to include in the model. All terms up to the - specified order are included. For restricted models, set zeros directly on the priors. - - endog_names: list of str, optional - Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If - None, the state names will simply be numbered. - - Exactly one of either ``endog_names`` or ``k_endog`` must be specified. - - k_endog: int, optional - Number of endogenous states to be modeled. - - Exactly one of either ``endog_names`` or ``k_endog`` must be specified. - - stationary_initialization: bool, default False - If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady - state values will be used. If False, the user is responsible for setting priors on the initial state and - initial covariance. - - ..warning :: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics - for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite" - error. 
- - filter_type: str, default "standard" - The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state", - and "cholesky". See the docs for kalman filters for more details. - - state_structure: str, default "fast" - How to represent the state-space system. When "interpretable", each element of the state vector will have a - precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger - state vector, which may hurt performance. - - When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are - mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation - in this case. - - measurement_error: bool, default True - If true, a measurement error term is added to the model. - - verbose: bool, default True - If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports. - - mode: str or Mode, optional - Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and - ``forecast``. The mode does **not** effect calls to ``pm.sample``. - - Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument - to all sampling methods. - - Notes - ----- The VARMA model is a multivariate extension of the SARIMAX model. Given a set of timeseries :math:`\{x_t\}_{t=0}^T`, with :math:`x_t = \begin{bmatrix} x_{1,t} & x_{2,t} & \cdots & x_{k,t} \end{bmatrix}^T`, a VARMA models each series as a function of the histories of all series. 
Specifically, denoting the AR-MA order as (p, q), a VARMA can be @@ -152,12 +98,77 @@ def __init__( order: tuple[int, int], endog_names: list[str] | None = None, k_endog: int | None = None, + exog_state_names: list[str] | dict[str, list[str]] | None = None, + k_exog: int | dict[str, int] | None = None, stationary_initialization: bool = False, filter_type: str = "standard", measurement_error: bool = False, verbose: bool = True, mode: str | Mode | None = None, ): + """ + Create a Bayesian VARMAX model. + + Parameters + ---------- + order: tuple of (int, int) + Number of autoregressive (AR) and moving average (MA) terms to include in the model. All terms up to the + specified order are included. For restricted models, set zeros directly on the priors. + + endog_names: list of str, optional + Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If + None, the state names will simply be numbered. + + Exactly one of either ``endog_names`` or ``k_endog`` must be specified. + + exog_state_names : list[str] or dict[str, list[str]], optional + Names of the exogenous state variables. If a list, all endogenous variables will share the same exogenous + variables. If a dict, keys should be the names of the endogenous variables, and values should be lists of the + exogenous variable names for that endogenous variable. Endogenous variables not included in the dict will + be assumed to have no exogenous variables. If None, no exogenous variables will be included. + + k_exog : int or dict[str, int], optional + Number of exogenous variables. If an int, all endogenous variables will share the same number of exogenous + variables. If a dict, keys should be the names of the endogenous variables, and values should be the number of + exogenous variables for that endogenous variable. Endogenous variables not included in the dict will be + assumed to have no exogenous variables. If None, no exogenous variables will be included. 
+ + stationary_initialization: bool, default False + If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady + state values will be used. If False, the user is responsible for setting priors on the initial state and + initial covariance. + + .. warning:: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics + for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite" + error. + + filter_type: str, default "standard" + The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state", + and "cholesky". See the docs for Kalman filters for more details. + + state_structure: str, default "fast" + How to represent the state-space system. When "interpretable", each element of the state vector will have a + precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger + state vector, which may hurt performance. + + When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are + mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation + in this case. + + measurement_error: bool, default False + If true, a measurement error term is added to the model. + + verbose: bool, default True + If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports. + + mode: str or Mode, optional + Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and + ``forecast``. The mode does **not** affect calls to ``pm.sample``. + + Regardless of whether a mode is specified, it can always be overridden via the ``compile_kwargs`` argument + to all sampling methods. 
+ + """ if (endog_names is None) and (k_endog is None): raise ValueError("Must specify either endog_names or k_endog") if (endog_names is not None) and (k_endog is None): @@ -168,6 +179,48 @@ def __init__( if len(endog_names) != k_endog: raise ValueError("Length of provided endog_names does not match provided k_endog") + if k_exog is not None and not isinstance(k_endog, int | dict): + raise ValueError("If not None, k_endog must be either an int or a dict") + if exog_state_names is not None and not isinstance(exog_state_names, list | dict): + raise ValueError("If not None, exog_state_names must be either a list or a dict") + + if k_exog is not None and exog_state_names is not None: + if isinstance(k_exog, int) and isinstance(exog_state_names, list): + if len(exog_state_names) != k_exog: + raise ValueError("Length of exog_state_names does not match provided k_exog") + elif isinstance(k_exog, int) and isinstance(exog_state_names, dict): + raise ValueError( + "If k_exog is an int, exog_state_names must be a list of the same length (or None)" + ) + elif isinstance(k_exog, dict) and isinstance(exog_state_names, list): + raise ValueError( + "If k_exog is a dict, exog_state_names must be a dict as well (or None)" + ) + elif isinstance(k_exog, dict) and isinstance(exog_state_names, dict): + if set(k_exog.keys()) != set(exog_state_names.keys()): + raise ValueError("Keys of k_exog and exog_state_names dicts must match") + if not all( + len(names) == k for names, k in zip(exog_state_names.values(), k_exog.values()) + ): + raise ValueError( + "If both k_endog and exog_state_names are provided, lengths of exog_state_names " + "lists must match corresponding values in k_exog" + ) + + if k_exog is not None and exog_state_names is None: + if isinstance(k_exog, int): + exog_state_names = [f"exog.{i + 1}" for i in range(k_exog)] + elif isinstance(k_exog, dict): + exog_state_names = { + name: [f"{name}.exog.{i + 1}" for i in range(k)] for name, k in k_exog.items() + } + + if k_exog is 
None and exog_state_names is not None: + if isinstance(exog_state_names, list): + k_exog = len(exog_state_names) + elif isinstance(exog_state_names, dict): + k_exog = {name: len(names) for name, names in exog_state_names.items()} + self.endog_names = list(endog_names) self.p, self.q = order self.stationary_initialization = stationary_initialization @@ -244,6 +297,10 @@ def param_info(self) -> dict[str, dict[str, Any]]: return {name: info[name] for name in self.param_names} + @property + def data_names(self) -> list[str]: + return self._exog_names + @property def state_names(self): state_names = self.endog_names.copy() diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index fbd0cfc04..ab15b5cbf 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -188,3 +188,149 @@ def test_impulse_response(parameters, varma_mod, idata, rng): irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters) assert not np.any(np.isnan(irf.irf.values)) + + +def test_create_varmax_with_exogenous(data): + # Case 1: k_exog as int, exog_state_names is None + mod = BayesianVARMAX( + k_endog=data.shape[1], + order=(1, 0), + k_exog=2, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == 2 + assert mod.exog_state_names == ["exog.1", "exog.2"] + + # Case 2: exog_state_names as list, k_exog is None + mod = BayesianVARMAX( + k_endog=data.shape[1], + order=(1, 0), + exog_state_names=["foo", "bar"], + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == 2 + assert mod.exog_state_names == ["foo", "bar"] + + # Case 3: k_exog as int, exog_state_names as list (matching) + mod = BayesianVARMAX( + k_endog=data.shape[1], + order=(1, 0), + k_exog=2, + exog_state_names=["a", "b"], + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == 2 + assert 
mod.exog_state_names == ["a", "b"] + + # Case 4: k_exog as dict, exog_state_names is None + k_exog = {"state.1": 2, "state.2": 1, "state.3": 0} + mod = BayesianVARMAX( + endog_names=["state.1", "state.2", "state.3"], + order=(1, 0), + k_exog=k_exog, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == k_exog + assert mod.exog_state_names == { + "state.1": ["state.1.exog.1", "state.1.exog.2"], + "state.2": ["state.2.exog.1"], + "state.3": [], + } + + # Case 5: exog_state_names as dict, k_exog is None + exog_state_names = {"state.1": ["a", "b"], "state.2": ["c"], "state.3": []} + mod = BayesianVARMAX( + endog_names=["state.1", "state.2", "state.3"], + order=(1, 0), + exog_state_names=exog_state_names, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == {"state.1": 2, "state.2": 1, "state.3": 0} + assert mod.exog_state_names == exog_state_names + + # Case 6: k_exog as dict, exog_state_names as dict (matching) + k_exog = {"state.1": 2, "state.2": 1} + exog_state_names = {"state.1": ["a", "b"], "state.2": ["c"]} + mod = BayesianVARMAX( + endog_names=["state.1", "state.2"], + order=(1, 0), + k_exog=k_exog, + exog_state_names=exog_state_names, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == k_exog + assert mod.exog_state_names == exog_state_names + + # Error: k_exog as int, exog_state_names as list (length mismatch) + with pytest.raises( + ValueError, match="Length of exog_state_names does not match provided k_exog" + ): + BayesianVARMAX( + k_endog=2, + order=(1, 0), + k_exog=3, + exog_state_names=["a", "b"], + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + + # Error: k_exog as int, exog_state_names as dict + with pytest.raises(ValueError): + BayesianVARMAX( + k_endog=2, + order=(1, 0), + k_exog=2, + exog_state_names={"state.1": ["a"], "state.2": ["b"]}, + verbose=False, + 
measurement_error=False, + stationary_initialization=False, + ) + + # Error: k_exog as dict, exog_state_names as list + with pytest.raises(ValueError): + BayesianVARMAX( + endog_names=["state.1", "state.2"], + order=(1, 0), + k_exog={"state.1": 1, "state.2": 1}, + exog_state_names=["a", "b"], + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + + # Error: k_exog as dict, exog_state_names as dict (keys mismatch) + with pytest.raises(ValueError, match="Keys of k_exog and exog_state_names dicts must match"): + BayesianVARMAX( + endog_names=["state.1", "state.2"], + order=(1, 0), + k_exog={"state.1": 1, "state.2": 1}, + exog_state_names={"state.1": ["a"], "state.3": ["b"]}, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + + # Error: k_exog as dict, exog_state_names as dict (length mismatch) + with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"): + BayesianVARMAX( + endog_names=["state.1", "state.2"], + order=(1, 0), + k_exog={"state.1": 2, "state.2": 1}, + exog_state_names={"state.1": ["a"], "state.2": ["b"]}, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) From ff44be594b5f2fd169d900fb2504d3d9c1fa6658 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 25 Aug 2025 14:01:23 +0800 Subject: [PATCH 2/5] Adjust state names for API consistency --- pymc_extras/statespace/models/VARMAX.py | 15 +++++-- tests/statespace/models/test_VARMAX.py | 56 +++++++++++++++---------- 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 37d592bd3..529b87819 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -174,7 +174,7 @@ def __init__( if (endog_names is not None) and (k_endog is None): k_endog = len(endog_names) if (endog_names is None) and (k_endog is not None): - endog_names = [f"state.{i + 1}" for i 
in range(k_endog)] + endog_names = [f"observed_{i}" for i in range(k_endog)] if (endog_names is not None) and (k_endog is not None): if len(endog_names) != k_endog: raise ValueError("Length of provided endog_names does not match provided k_endog") @@ -209,10 +209,10 @@ def __init__( if k_exog is not None and exog_state_names is None: if isinstance(k_exog, int): - exog_state_names = [f"exog.{i + 1}" for i in range(k_exog)] + exog_state_names = [f"exogenous_{i}" for i in range(k_exog)] elif isinstance(k_exog, dict): exog_state_names = { - name: [f"{name}.exog.{i + 1}" for i in range(k)] for name, k in k_exog.items() + name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items() } if k_exog is None and exog_state_names is not None: @@ -222,6 +222,9 @@ def __init__( k_exog = {name: len(names) for name, names in exog_state_names.items()} self.endog_names = list(endog_names) + self.exog_state_names = exog_state_names + + self.k_exog = k_exog self.p, self.q = order self.stationary_initialization = stationary_initialization @@ -299,7 +302,11 @@ def param_info(self) -> dict[str, dict[str, Any]]: @property def data_names(self) -> list[str]: - return self._exog_names + if isinstance(self.exog_state_names, list): + return ["exogenous_data"] + elif isinstance(self.exog_state_names, dict): + return [f"{endog_state}_exogenous_data" for endog_state in self.exog_state_names.keys()] + return [] @property def state_names(self): diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index ab15b5cbf..4daa4c39b 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -201,7 +201,8 @@ def test_create_varmax_with_exogenous(data): stationary_initialization=False, ) assert mod.k_exog == 2 - assert mod.exog_state_names == ["exog.1", "exog.2"] + assert mod.exog_state_names == ["exogenous_0", "exogenous_1"] + assert mod.data_names == ["exogenous_data"] # Case 2: exog_state_names as list, k_exog 
is None mod = BayesianVARMAX( @@ -214,6 +215,7 @@ def test_create_varmax_with_exogenous(data): ) assert mod.k_exog == 2 assert mod.exog_state_names == ["foo", "bar"] + assert mod.data_names == ["exogenous_data"] # Case 3: k_exog as int, exog_state_names as list (matching) mod = BayesianVARMAX( @@ -227,11 +229,12 @@ def test_create_varmax_with_exogenous(data): ) assert mod.k_exog == 2 assert mod.exog_state_names == ["a", "b"] + assert mod.data_names == ["exogenous_data"] # Case 4: k_exog as dict, exog_state_names is None - k_exog = {"state.1": 2, "state.2": 1, "state.3": 0} + k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0} mod = BayesianVARMAX( - endog_names=["state.1", "state.2", "state.3"], + endog_names=["observed_0", "observed_1", "observed_2"], order=(1, 0), k_exog=k_exog, verbose=False, @@ -240,29 +243,39 @@ def test_create_varmax_with_exogenous(data): ) assert mod.k_exog == k_exog assert mod.exog_state_names == { - "state.1": ["state.1.exog.1", "state.1.exog.2"], - "state.2": ["state.2.exog.1"], - "state.3": [], + "observed_0": ["observed_0_exogenous_0", "observed_0_exogenous_1"], + "observed_1": ["observed_1_exogenous_0"], + "observed_2": [], } + assert mod.data_names == [ + "observed_0_exogenous_data", + "observed_1_exogenous_data", + "observed_2_exogenous_data", + ] # Case 5: exog_state_names as dict, k_exog is None - exog_state_names = {"state.1": ["a", "b"], "state.2": ["c"], "state.3": []} + exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []} mod = BayesianVARMAX( - endog_names=["state.1", "state.2", "state.3"], + endog_names=["observed_0", "observed_1", "observed_2"], order=(1, 0), exog_state_names=exog_state_names, verbose=False, measurement_error=False, stationary_initialization=False, ) - assert mod.k_exog == {"state.1": 2, "state.2": 1, "state.3": 0} + assert mod.k_exog == {"observed_0": 2, "observed_1": 1, "observed_2": 0} assert mod.exog_state_names == exog_state_names + assert mod.data_names == [ + 
"observed_0_exogenous_data", + "observed_1_exogenous_data", + "observed_2_exogenous_data", + ] # Case 6: k_exog as dict, exog_state_names as dict (matching) - k_exog = {"state.1": 2, "state.2": 1} - exog_state_names = {"state.1": ["a", "b"], "state.2": ["c"]} + k_exog = {"observed_0": 2, "observed_1": 1} + exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"]} mod = BayesianVARMAX( - endog_names=["state.1", "state.2"], + endog_names=["observed_0", "observed_1"], order=(1, 0), k_exog=k_exog, exog_state_names=exog_state_names, @@ -272,6 +285,7 @@ def test_create_varmax_with_exogenous(data): ) assert mod.k_exog == k_exog assert mod.exog_state_names == exog_state_names + assert mod.data_names == ["observed_0_exogenous_data", "observed_1_exogenous_data"] # Error: k_exog as int, exog_state_names as list (length mismatch) with pytest.raises( @@ -293,7 +307,7 @@ def test_create_varmax_with_exogenous(data): k_endog=2, order=(1, 0), k_exog=2, - exog_state_names={"state.1": ["a"], "state.2": ["b"]}, + exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, verbose=False, measurement_error=False, stationary_initialization=False, @@ -302,9 +316,9 @@ def test_create_varmax_with_exogenous(data): # Error: k_exog as dict, exog_state_names as list with pytest.raises(ValueError): BayesianVARMAX( - endog_names=["state.1", "state.2"], + endog_names=["observed_0", "observed_1"], order=(1, 0), - k_exog={"state.1": 1, "state.2": 1}, + k_exog={"observed_0": 1, "observed_1": 1}, exog_state_names=["a", "b"], verbose=False, measurement_error=False, @@ -314,10 +328,10 @@ def test_create_varmax_with_exogenous(data): # Error: k_exog as dict, exog_state_names as dict (keys mismatch) with pytest.raises(ValueError, match="Keys of k_exog and exog_state_names dicts must match"): BayesianVARMAX( - endog_names=["state.1", "state.2"], + endog_names=["observed_0", "observed_1"], order=(1, 0), - k_exog={"state.1": 1, "state.2": 1}, - exog_state_names={"state.1": ["a"], "state.3": ["b"]}, 
+ k_exog={"observed_0": 1, "observed_1": 1}, + exog_state_names={"observed_0": ["a"], "observed_2": ["b"]}, verbose=False, measurement_error=False, stationary_initialization=False, @@ -326,10 +340,10 @@ def test_create_varmax_with_exogenous(data): # Error: k_exog as dict, exog_state_names as dict (length mismatch) with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"): BayesianVARMAX( - endog_names=["state.1", "state.2"], + endog_names=["observed_0", "observed_1"], order=(1, 0), - k_exog={"state.1": 2, "state.2": 1}, - exog_state_names={"state.1": ["a"], "state.2": ["b"]}, + k_exog={"observed_0": 2, "observed_1": 1}, + exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, verbose=False, measurement_error=False, stationary_initialization=False, From 604f77698dae38f0ead662eca989dbf25aa61e8a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 25 Aug 2025 21:58:41 +0800 Subject: [PATCH 3/5] Allow exogenous variables in BayesianVARMAX --- pymc_extras/statespace/models/VARMAX.py | 121 +++++++++++++++++- tests/statespace/models/test_VARMAX.py | 161 ++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 2 deletions(-) diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 529b87819..1450010bc 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -14,11 +14,13 @@ ALL_STATE_AUX_DIM, ALL_STATE_DIM, AR_PARAM_DIM, + EXOGENOUS_DIM, MA_PARAM_DIM, OBS_STATE_AUX_DIM, OBS_STATE_DIM, SHOCK_AUX_DIM, SHOCK_DIM, + TIME_DIM, ) floatX = pytensor.config.floatX @@ -264,6 +266,14 @@ def param_names(self): names.remove("ar_params") if self.q == 0: names.remove("ma_params") + + # Add exogenous regression coefficents rather than remove, since we might have to handle + # several (if self.exog_state_names is a dict) + if isinstance(self.exog_state_names, list): + names.append("beta_exog") + elif isinstance(self.exog_state_names, dict): + 
names.extend([f"beta_{name}" for name in self.exog_state_names.keys()]) + return names @property @@ -295,11 +305,49 @@ def param_info(self) -> dict[str, dict[str, Any]]: }, } + if isinstance(self.exog_state_names, list): + k_exog = len(self.exog_state_names) + info["beta_exog"] = { + "shape": (self.k_endog, k_exog), + "constraints": "None", + } + + elif isinstance(self.exog_state_names, dict): + for name, exog_names in self.exog_state_names.items(): + k_exog = len(exog_names) + info[f"beta_{name}"] = { + "shape": (k_exog,), + "constraints": "None", + } + for name in self.param_names: info[name]["dims"] = self.param_dims[name] return {name: info[name] for name in self.param_names} + @property + def data_info(self) -> dict[str, dict[str, Any]]: + info = None + + if isinstance(self.exog_state_names, list): + info = { + "exogenous_data": { + "dims": (TIME_DIM, EXOGENOUS_DIM), + "shape": (None, self.k_exog), + } + } + + elif isinstance(self.exog_state_names, dict): + info = { + f"{endog_state}_exogenous_data": { + "dims": (TIME_DIM, f"{EXOGENOUS_DIM}_{endog_state}"), + "shape": (None, len(exog_names)), + } + for endog_state, exog_names in self.exog_state_names.items() + } + + return info + @property def data_names(self) -> list[str]: if isinstance(self.exog_state_names, list): @@ -312,10 +360,10 @@ def data_names(self) -> list[str]: def state_names(self): state_names = self.endog_names.copy() state_names += [ - f"L{i + 1}.{state}" for i in range(self.p - 1) for state in self.endog_names + f"L{i + 1}_{state}" for i in range(self.p - 1) for state in self.endog_names ] state_names += [ - f"L{i + 1}.{state}_innov" for i in range(self.q) for state in self.endog_names + f"L{i + 1}_{state}_innov" for i in range(self.q) for state in self.endog_names ] return state_names @@ -340,6 +388,12 @@ def coords(self) -> dict[str, Sequence]: if self.q > 0: coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))}) + if isinstance(self.exog_state_names, list): + coords[EXOGENOUS_DIM] = 
self.exog_state_names + elif isinstance(self.exog_state_names, dict): + for name, exog_names in self.exog_state_names.items(): + coords[f"{EXOGENOUS_DIM}_{name}"] = exog_names + return coords @property @@ -363,6 +417,14 @@ def param_dims(self): del coord_map["P0"] del coord_map["x0"] + if isinstance(self.exog_state_names, list): + coord_map["beta_exog"] = (OBS_STATE_DIM, EXOGENOUS_DIM) + elif isinstance(self.exog_state_names, dict): + # If each state has its own exogenous variables, each parameter needs it own dim, since we expect the + # dim labels to all be different (otherwise we'd be in the list case). + for name in self.exog_state_names.keys(): + coord_map[f"beta_{name}"] = (f"{EXOGENOUS_DIM}_{name}",) + return coord_map def add_default_priors(self): @@ -450,6 +512,61 @@ def make_symbolic_graph(self) -> None: ) self.ssm["state_cov", :, :] = state_cov + if self.exog_state_names is not None: + if isinstance(self.exog_state_names, list): + beta_exog = self.make_and_register_variable( + "beta_exog", shape=(self.k_posdef, self.k_exog), dtype=floatX + ) + exog_data = self.make_and_register_data( + "exogenous_data", shape=(None, self.k_exog), dtype=floatX + ) + + obs_intercept = exog_data @ beta_exog.T + + elif isinstance(self.exog_state_names, dict): + obs_components = [] + for i, name in enumerate(self.endog_names): + if name in self.exog_state_names: + k_exog = len(self.exog_state_names[name]) + beta_exog = self.make_and_register_variable( + f"beta_{name}", shape=(k_exog,), dtype=floatX + ) + exog_data = self.make_and_register_data( + f"{name}_exogenous_data", shape=(None, k_exog), dtype=floatX + ) + obs_components.append(pt.expand_dims(exog_data @ beta_exog, axis=-1)) + else: + obs_components.append(pt.zeros((1, 1), dtype=floatX)) + + # TODO: Replace all of this with pt.concat_with_broadcast once PyMC works with pytensor >= 2.32 + + # If there were any zeros, they need to be broadcast against the non-zeros. 
+ # Core shape is the last dim, the time dim is always broadcast + non_concat_shape = [1, None] + + # Look for the first non-zero component to get the shape from + for tensor_inp in obs_components: + for i, (bcast, sh) in enumerate( + zip(tensor_inp.type.broadcastable, tensor_inp.shape) + ): + if bcast or i == 1: + continue + non_concat_shape[i] = sh + + assert non_concat_shape.count(None) == 1 + + bcast_tensor_inputs = [] + for tensor_inp in obs_components: + non_concat_shape[1] = tensor_inp.shape[1] + bcast_tensor_inputs.append(pt.broadcast_to(tensor_inp, non_concat_shape)) + + obs_intercept = pt.join(1, *bcast_tensor_inputs) + + else: + raise NotImplementedError() + + self.ssm["obs_intercept"] = obs_intercept + if self.stationary_initialization: # Solve for matrix quadratic for P0 T = self.ssm["transition"] diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index 4daa4c39b..6742b75a5 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -9,6 +9,7 @@ import statsmodels.api as sm from numpy.testing import assert_allclose, assert_array_less +from pymc.model.transform.optimization import freeze_dims_and_data from pymc_extras.statespace import BayesianVARMAX from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG @@ -203,6 +204,10 @@ def test_create_varmax_with_exogenous(data): assert mod.k_exog == 2 assert mod.exog_state_names == ["exogenous_0", "exogenous_1"] assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["exogenous_0", "exogenous_1"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") # Case 2: exog_state_names as list, k_exog is None mod = BayesianVARMAX( @@ -216,6 +221,10 @@ def test_create_varmax_with_exogenous(data): assert mod.k_exog == 2 assert 
mod.exog_state_names == ["foo", "bar"] assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["foo", "bar"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") # Case 3: k_exog as int, exog_state_names as list (matching) mod = BayesianVARMAX( @@ -230,6 +239,10 @@ def test_create_varmax_with_exogenous(data): assert mod.k_exog == 2 assert mod.exog_state_names == ["a", "b"] assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["a", "b"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") # Case 4: k_exog as dict, exog_state_names is None k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0} @@ -252,6 +265,25 @@ def test_create_varmax_with_exogenous(data): "observed_1_exogenous_data", "observed_2_exogenous_data", ] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert ( + "beta_observed_2" not in mod.param_dims + or mod.param_info.get("beta_observed_2") is None + or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 + ) + + assert mod.coords["exogenous_observed_0"] == [ + "observed_0_exogenous_0", + "observed_0_exogenous_1", + ] + assert mod.coords["exogenous_observed_1"] == ["observed_1_exogenous_0"] + assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] + + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) # 
Case 5: exog_state_names as dict, k_exog is None exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []} @@ -270,6 +302,22 @@ def test_create_varmax_with_exogenous(data): "observed_1_exogenous_data", "observed_2_exogenous_data", ] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert ( + "beta_observed_2" not in mod.param_dims + or mod.param_info.get("beta_observed_2") is None + or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 + ) + + assert mod.coords["exogenous_observed_0"] == ["a", "b"] + assert mod.coords["exogenous_observed_1"] == ["c"] + assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] + + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) # Case 6: k_exog as dict, exog_state_names as dict (matching) k_exog = {"observed_0": 2, "observed_1": 1} @@ -286,6 +334,14 @@ def test_create_varmax_with_exogenous(data): assert mod.k_exog == k_exog assert mod.exog_state_names == exog_state_names assert mod.data_names == ["observed_0_exogenous_data", "observed_1_exogenous_data"] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert mod.coords["exogenous_observed_0"] == ["a", "b"] + assert mod.coords["exogenous_observed_1"] == ["c"] + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) # Error: k_exog as int, exog_state_names as list (length 
mismatch) with pytest.raises( @@ -348,3 +404,108 @@ def test_create_varmax_with_exogenous(data): measurement_error=False, stationary_initialization=False, ) + + +@pytest.mark.parametrize( + "k_exog, exog_state_names", + [ + (2, None), + (None, ["foo", "bar"]), + (None, {"y1": ["a", "b"], "y2": ["c"]}), + ], + ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"], +) +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_varmax_with_exog(rng, k_exog, exog_state_names): + endog_names = ["y1", "y2", "y3"] + n_obs = 50 + time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D") + + y = rng.normal(size=(n_obs, len(endog_names))) + df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX) + + if isinstance(exog_state_names, dict): + exog_data = { + f"{name}_exogenous_data": pd.DataFrame( + rng.normal(size=(n_obs, len(exog_names))).astype(floatX), + columns=exog_names, + index=time_idx, + ) + for name, exog_names in exog_state_names.items() + } + else: + exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)] + exog_data = { + "exogenous_data": pd.DataFrame( + rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX), + columns=exog_names, + index=time_idx, + ) + } + + mod = BayesianVARMAX( + endog_names=endog_names, + order=(1, 0), + k_exog=k_exog, + exog_state_names=exog_state_names, + verbose=True, + measurement_error=False, + stationary_initialization=False, + mode="JAX", + ) + + with pm.Model(coords=mod.coords) as m: + for var_name, data in exog_data.items(): + pm.Data(var_name, data, dims=mod.data_info[var_name]["dims"]) + + x0 = pm.Deterministic("x0", pt.zeros(mod.k_states), dims=mod.param_dims["x0"]) + P0_diag = pm.Exponential("P0_diag", 1.0, dims=mod.param_dims["P0"][0]) + P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=mod.param_dims["P0"]) + + ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=mod.param_dims["ar_params"]) + state_cov_diag = pm.Exponential("state_cov_diag", 
1.0, dims=mod.param_dims["state_cov"][0]) + state_cov = pm.Deterministic( + "state_cov", pt.diag(state_cov_diag), dims=mod.param_dims["state_cov"] + ) + + # Exogenous priors + if isinstance(mod.exog_state_names, list): + beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=mod.param_dims["beta_exog"]) + elif isinstance(mod.exog_state_names, dict): + for name in mod.exog_state_names: + if mod.exog_state_names.get(name): + pm.Normal(f"beta_{name}", mu=0, sigma=1, dims=mod.param_dims[f"beta_{name}"]) + + mod.build_statespace_graph(data=df) + + with freeze_dims_and_data(m): + prior = pm.sample_prior_predictive( + draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"} + ) + + prior_cond = mod.sample_conditional_prior(prior, mvn_method="eigh") + beta_dot_data = prior_cond.filtered_prior_observed.values - prior_cond.filtered_prior.values + + if isinstance(exog_state_names, list) or k_exog is not None: + beta = prior.prior.beta_exog + assert beta.shape == (1, 10, 3, 2) + + np.testing.assert_allclose( + beta_dot_data, + np.einsum("tx,...sx->...ts", exog_data["exogenous_data"].values, beta), + atol=1e-2, + ) + + elif isinstance(exog_state_names, dict): + assert prior.prior.beta_y1.shape == (1, 10, 2) + assert prior.prior.beta_y2.shape == (1, 10, 1) + + obs_intercept = [ + np.einsum("tx,...x->...t", exog_data[f"{name}_exogenous_data"].values, beta) + for name, beta in zip(["y1", "y2"], [prior.prior.beta_y1, prior.prior.beta_y2]) + ] + + # y3 has no exogenous variables + obs_intercept.append(np.zeros_like(obs_intercept[0])) + + np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2) From 31cba0b056d616eb59575831821dc5c94408ffb0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 25 Aug 2025 22:47:53 +0800 Subject: [PATCH 4/5] Eagerly simplify model where possible --- pymc_extras/statespace/models/VARMAX.py | 10 ++++ tests/statespace/models/test_VARMAX.py | 62 +++++++++++++++++++------ 2 files changed, 58 insertions(+), 14 deletions(-) 
diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 1450010bc..72f5b9013 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -223,6 +223,16 @@ def __init__( elif isinstance(exog_state_names, dict): k_exog = {name: len(names) for name, names in exog_state_names.items()} + # If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same + # then we can drop back to the list case. + if ( + isinstance(exog_state_names, dict) + and set(exog_state_names.keys()) == set(endog_names) + and len({frozenset(val) for val in exog_state_names.values()}) == 1 + ): + exog_state_names = exog_state_names[endog_names[0]] + k_exog = len(exog_state_names) + self.endog_names = list(endog_names) self.exog_state_names = exog_state_names diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index 6742b75a5..3dda9e943 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -191,8 +191,7 @@ def test_impulse_response(parameters, varma_mod, idata, rng): assert not np.any(np.isnan(irf.irf.values)) -def test_create_varmax_with_exogenous(data): - # Case 1: k_exog as int, exog_state_names is None +def test_create_varmax_with_exogenous_k_exog_int(data): mod = BayesianVARMAX( k_endog=data.shape[1], order=(1, 0), @@ -209,7 +208,8 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - # Case 2: exog_state_names as list, k_exog is None + +def test_create_varmax_with_exogenous_list_of_names(data): mod = BayesianVARMAX( k_endog=data.shape[1], order=(1, 0), @@ -226,7 +226,8 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) assert mod.param_info["beta_exog"]["dims"] == ("observed_state", 
"exogenous") - # Case 3: k_exog as int, exog_state_names as list (matching) + +def test_create_varmax_with_exogenous_both_defined_correctly(data): mod = BayesianVARMAX( k_endog=data.shape[1], order=(1, 0), @@ -244,7 +245,8 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - # Case 4: k_exog as dict, exog_state_names is None + +def test_create_varmax_with_exogenous_k_exog_dict(data): k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0} mod = BayesianVARMAX( endog_names=["observed_0", "observed_1", "observed_2"], @@ -285,7 +287,8 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_observed_1"]["shape"] == (1,) assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - # Case 5: exog_state_names as dict, k_exog is None + +def test_create_varmax_with_exogenous_exog_names_dict(data): exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []} mod = BayesianVARMAX( endog_names=["observed_0", "observed_1", "observed_2"], @@ -319,7 +322,8 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_observed_1"]["shape"] == (1,) assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - # Case 6: k_exog as dict, exog_state_names as dict (matching) + +def test_create_varmax_with_exogenous_both_dict_correct(data): k_exog = {"observed_0": 2, "observed_1": 1} exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"]} mod = BayesianVARMAX( @@ -343,7 +347,33 @@ def test_create_varmax_with_exogenous(data): assert mod.param_info["beta_observed_1"]["shape"] == (1,) assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - # Error: k_exog as int, exog_state_names as list (length mismatch) + +def test_create_varmax_with_exogenous_dict_converts_to_list(data): + exog_state_names = { + "observed_0": ["a", 
"b"], + "observed_1": ["a", "b"], + "observed_2": ["a", "b"], + } + mod = BayesianVARMAX( + endog_names=["observed_0", "observed_1", "observed_2"], + order=(1, 0), + exog_state_names=exog_state_names, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + + assert mod.k_exog == 2 + assert mod.exog_state_names == ["a", "b"] + assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["a", "b"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") + + +def test_create_varmax_with_exogenous_raises_if_args_disagree(data): + # List case with pytest.raises( ValueError, match="Length of exog_state_names does not match provided k_exog" ): @@ -357,8 +387,10 @@ def test_create_varmax_with_exogenous(data): stationary_initialization=False, ) - # Error: k_exog as int, exog_state_names as dict - with pytest.raises(ValueError): + # Dict case + with pytest.raises( + ValueError, match="If k_exog is an int, exog_state_names must be a list of the same length" + ): BayesianVARMAX( k_endog=2, order=(1, 0), @@ -369,8 +401,10 @@ def test_create_varmax_with_exogenous(data): stationary_initialization=False, ) - # Error: k_exog as dict, exog_state_names as list - with pytest.raises(ValueError): + # dict + list + with pytest.raises( + ValueError, match="If k_exog is a dict, exog_state_names must be a dict as well" + ): BayesianVARMAX( endog_names=["observed_0", "observed_1"], order=(1, 0), @@ -381,7 +415,7 @@ def test_create_varmax_with_exogenous(data): stationary_initialization=False, ) - # Error: k_exog as dict, exog_state_names as dict (keys mismatch) + # Dict/dict, key mismatch with pytest.raises(ValueError, match="Keys of k_exog and exog_state_names dicts must match"): BayesianVARMAX( endog_names=["observed_0", "observed_1"], @@ -393,7 +427,7 @@ def 
test_create_varmax_with_exogenous(data): stationary_initialization=False, ) - # Error: k_exog as dict, exog_state_names as dict (length mismatch) + # Dict/dict, length mismatch with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"): BayesianVARMAX( endog_names=["observed_0", "observed_1"], From caa62d2d6a40ee09637aabbf80e108fd9c14ffca Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 25 Aug 2025 23:06:03 +0800 Subject: [PATCH 5/5] Typo fix --- pymc_extras/statespace/models/VARMAX.py | 2 +- tests/statespace/models/test_VARMAX.py | 610 ++++++++++++------------ 2 files changed, 306 insertions(+), 306 deletions(-) diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 72f5b9013..1c02ef816 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -181,7 +181,7 @@ def __init__( if len(endog_names) != k_endog: raise ValueError("Length of provided endog_names does not match provided k_endog") - if k_exog is not None and not isinstance(k_endog, int | dict): + if k_exog is not None and not isinstance(k_exog, int | dict): raise ValueError("If not None, k_endog must be either an int or a dict") if exog_state_names is not None and not isinstance(exog_state_names, list | dict): raise ValueError("If not None, exog_state_names must be either a list or a dict") diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index 3dda9e943..52c5542f9 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -191,355 +191,355 @@ def test_impulse_response(parameters, varma_mod, idata, rng): assert not np.any(np.isnan(irf.irf.values)) -def test_create_varmax_with_exogenous_k_exog_int(data): - mod = BayesianVARMAX( - k_endog=data.shape[1], - order=(1, 0), - k_exog=2, - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert mod.k_exog == 2 - assert 
mod.exog_state_names == ["exogenous_0", "exogenous_1"] - assert mod.data_names == ["exogenous_data"] - assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") - assert mod.coords["exogenous"] == ["exogenous_0", "exogenous_1"] - assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) - assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - - -def test_create_varmax_with_exogenous_list_of_names(data): - mod = BayesianVARMAX( - k_endog=data.shape[1], - order=(1, 0), - exog_state_names=["foo", "bar"], - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert mod.k_exog == 2 - assert mod.exog_state_names == ["foo", "bar"] - assert mod.data_names == ["exogenous_data"] - assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") - assert mod.coords["exogenous"] == ["foo", "bar"] - assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) - assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - - -def test_create_varmax_with_exogenous_both_defined_correctly(data): - mod = BayesianVARMAX( - k_endog=data.shape[1], - order=(1, 0), - k_exog=2, - exog_state_names=["a", "b"], - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert mod.k_exog == 2 - assert mod.exog_state_names == ["a", "b"] - assert mod.data_names == ["exogenous_data"] - assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") - assert mod.coords["exogenous"] == ["a", "b"] - assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) - assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - - -def test_create_varmax_with_exogenous_k_exog_dict(data): - k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0} - mod = BayesianVARMAX( - endog_names=["observed_0", "observed_1", "observed_2"], - order=(1, 0), - k_exog=k_exog, - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert 
mod.k_exog == k_exog - assert mod.exog_state_names == { - "observed_0": ["observed_0_exogenous_0", "observed_0_exogenous_1"], - "observed_1": ["observed_1_exogenous_0"], - "observed_2": [], - } - assert mod.data_names == [ - "observed_0_exogenous_data", - "observed_1_exogenous_data", - "observed_2_exogenous_data", - ] - assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) - assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) - assert ( - "beta_observed_2" not in mod.param_dims - or mod.param_info.get("beta_observed_2") is None - or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 - ) - - assert mod.coords["exogenous_observed_0"] == [ - "observed_0_exogenous_0", - "observed_0_exogenous_1", - ] - assert mod.coords["exogenous_observed_1"] == ["observed_1_exogenous_0"] - assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] - - assert mod.param_info["beta_observed_0"]["shape"] == (2,) - assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) - assert mod.param_info["beta_observed_1"]["shape"] == (1,) - assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - - -def test_create_varmax_with_exogenous_exog_names_dict(data): - exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []} - mod = BayesianVARMAX( - endog_names=["observed_0", "observed_1", "observed_2"], - order=(1, 0), - exog_state_names=exog_state_names, - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert mod.k_exog == {"observed_0": 2, "observed_1": 1, "observed_2": 0} - assert mod.exog_state_names == exog_state_names - assert mod.data_names == [ - "observed_0_exogenous_data", - "observed_1_exogenous_data", - "observed_2_exogenous_data", - ] - assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) - assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) - assert ( - 
"beta_observed_2" not in mod.param_dims - or mod.param_info.get("beta_observed_2") is None - or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 - ) - - assert mod.coords["exogenous_observed_0"] == ["a", "b"] - assert mod.coords["exogenous_observed_1"] == ["c"] - assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] - - assert mod.param_info["beta_observed_0"]["shape"] == (2,) - assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) - assert mod.param_info["beta_observed_1"]["shape"] == (1,) - assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - - -def test_create_varmax_with_exogenous_both_dict_correct(data): - k_exog = {"observed_0": 2, "observed_1": 1} - exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"]} - mod = BayesianVARMAX( - endog_names=["observed_0", "observed_1"], - order=(1, 0), - k_exog=k_exog, - exog_state_names=exog_state_names, - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - assert mod.k_exog == k_exog - assert mod.exog_state_names == exog_state_names - assert mod.data_names == ["observed_0_exogenous_data", "observed_1_exogenous_data"] - assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) - assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) - assert mod.coords["exogenous_observed_0"] == ["a", "b"] - assert mod.coords["exogenous_observed_1"] == ["c"] - assert mod.param_info["beta_observed_0"]["shape"] == (2,) - assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) - assert mod.param_info["beta_observed_1"]["shape"] == (1,) - assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) - - -def test_create_varmax_with_exogenous_dict_converts_to_list(data): - exog_state_names = { - "observed_0": ["a", "b"], - "observed_1": ["a", "b"], - "observed_2": ["a", "b"], - } - mod = BayesianVARMAX( - endog_names=["observed_0", 
"observed_1", "observed_2"], - order=(1, 0), - exog_state_names=exog_state_names, - verbose=False, - measurement_error=False, - stationary_initialization=False, - ) - - assert mod.k_exog == 2 - assert mod.exog_state_names == ["a", "b"] - assert mod.data_names == ["exogenous_data"] - assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") - assert mod.coords["exogenous"] == ["a", "b"] - assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) - assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") - - -def test_create_varmax_with_exogenous_raises_if_args_disagree(data): - # List case - with pytest.raises( - ValueError, match="Length of exog_state_names does not match provided k_exog" - ): - BayesianVARMAX( - k_endog=2, +class TestVARMAXWithExogenous: + def test_create_varmax_with_exogenous_k_exog_int(self, data): + mod = BayesianVARMAX( + k_endog=data.shape[1], order=(1, 0), - k_exog=3, - exog_state_names=["a", "b"], + k_exog=2, verbose=False, measurement_error=False, stationary_initialization=False, ) - - # Dict case - with pytest.raises( - ValueError, match="If k_exog is an int, exog_state_names must be a list of the same length" - ): - BayesianVARMAX( - k_endog=2, + assert mod.k_exog == 2 + assert mod.exog_state_names == ["exogenous_0", "exogenous_1"] + assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["exogenous_0", "exogenous_1"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") + + def test_create_varmax_with_exogenous_list_of_names(self, data): + mod = BayesianVARMAX( + k_endog=data.shape[1], order=(1, 0), - k_exog=2, - exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, + exog_state_names=["foo", "bar"], verbose=False, measurement_error=False, stationary_initialization=False, ) - - # dict + list - with 
pytest.raises( - ValueError, match="If k_exog is a dict, exog_state_names must be a dict as well" - ): - BayesianVARMAX( - endog_names=["observed_0", "observed_1"], + assert mod.k_exog == 2 + assert mod.exog_state_names == ["foo", "bar"] + assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["foo", "bar"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") + + def test_create_varmax_with_exogenous_both_defined_correctly(self, data): + mod = BayesianVARMAX( + k_endog=data.shape[1], order=(1, 0), - k_exog={"observed_0": 1, "observed_1": 1}, + k_exog=2, exog_state_names=["a", "b"], verbose=False, measurement_error=False, stationary_initialization=False, ) + assert mod.k_exog == 2 + assert mod.exog_state_names == ["a", "b"] + assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["a", "b"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") + + def test_create_varmax_with_exogenous_k_exog_dict(self, data): + k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0} + mod = BayesianVARMAX( + endog_names=["observed_0", "observed_1", "observed_2"], + order=(1, 0), + k_exog=k_exog, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == k_exog + assert mod.exog_state_names == { + "observed_0": ["observed_0_exogenous_0", "observed_0_exogenous_1"], + "observed_1": ["observed_1_exogenous_0"], + "observed_2": [], + } + assert mod.data_names == [ + "observed_0_exogenous_data", + "observed_1_exogenous_data", + "observed_2_exogenous_data", + ] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert 
mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert ( + "beta_observed_2" not in mod.param_dims + or mod.param_info.get("beta_observed_2") is None + or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 + ) - # Dict/dict, key mismatch - with pytest.raises(ValueError, match="Keys of k_exog and exog_state_names dicts must match"): - BayesianVARMAX( - endog_names=["observed_0", "observed_1"], + assert mod.coords["exogenous_observed_0"] == [ + "observed_0_exogenous_0", + "observed_0_exogenous_1", + ] + assert mod.coords["exogenous_observed_1"] == ["observed_1_exogenous_0"] + assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] + + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) + + def test_create_varmax_with_exogenous_exog_names_dict(self, data): + exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []} + mod = BayesianVARMAX( + endog_names=["observed_0", "observed_1", "observed_2"], order=(1, 0), - k_exog={"observed_0": 1, "observed_1": 1}, - exog_state_names={"observed_0": ["a"], "observed_2": ["b"]}, + exog_state_names=exog_state_names, verbose=False, measurement_error=False, stationary_initialization=False, ) + assert mod.k_exog == {"observed_0": 2, "observed_1": 1, "observed_2": 0} + assert mod.exog_state_names == exog_state_names + assert mod.data_names == [ + "observed_0_exogenous_data", + "observed_1_exogenous_data", + "observed_2_exogenous_data", + ] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert ( + "beta_observed_2" not in mod.param_dims + or mod.param_info.get("beta_observed_2") is None + or 
mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0 + ) + + assert mod.coords["exogenous_observed_0"] == ["a", "b"] + assert mod.coords["exogenous_observed_1"] == ["c"] + assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == [] - # Dict/dict, length mismatch - with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"): - BayesianVARMAX( + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) + + def test_create_varmax_with_exogenous_both_dict_correct(self, data): + k_exog = {"observed_0": 2, "observed_1": 1} + exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"]} + mod = BayesianVARMAX( endog_names=["observed_0", "observed_1"], order=(1, 0), - k_exog={"observed_0": 2, "observed_1": 1}, - exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, + k_exog=k_exog, + exog_state_names=exog_state_names, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) + assert mod.k_exog == k_exog + assert mod.exog_state_names == exog_state_names + assert mod.data_names == ["observed_0_exogenous_data", "observed_1_exogenous_data"] + assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",) + assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",) + assert mod.coords["exogenous_observed_0"] == ["a", "b"] + assert mod.coords["exogenous_observed_1"] == ["c"] + assert mod.param_info["beta_observed_0"]["shape"] == (2,) + assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",) + assert mod.param_info["beta_observed_1"]["shape"] == (1,) + assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",) + + def test_create_varmax_with_exogenous_dict_converts_to_list(self, data): + 
exog_state_names = { + "observed_0": ["a", "b"], + "observed_1": ["a", "b"], + "observed_2": ["a", "b"], + } + mod = BayesianVARMAX( + endog_names=["observed_0", "observed_1", "observed_2"], + order=(1, 0), + exog_state_names=exog_state_names, verbose=False, measurement_error=False, stationary_initialization=False, ) - -@pytest.mark.parametrize( - "k_exog, exog_state_names", - [ - (2, None), - (None, ["foo", "bar"]), - (None, {"y1": ["a", "b"], "y2": ["c"]}), - ], - ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"], -) -@pytest.mark.filterwarnings("ignore::UserWarning") -def test_varmax_with_exog(rng, k_exog, exog_state_names): - endog_names = ["y1", "y2", "y3"] - n_obs = 50 - time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D") - - y = rng.normal(size=(n_obs, len(endog_names))) - df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX) - - if isinstance(exog_state_names, dict): - exog_data = { - f"{name}_exogenous_data": pd.DataFrame( - rng.normal(size=(n_obs, len(exog_names))).astype(floatX), - columns=exog_names, - index=time_idx, + assert mod.k_exog == 2 + assert mod.exog_state_names == ["a", "b"] + assert mod.data_names == ["exogenous_data"] + assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous") + assert mod.coords["exogenous"] == ["a", "b"] + assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2) + assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous") + + def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data): + # List case + with pytest.raises( + ValueError, match="Length of exog_state_names does not match provided k_exog" + ): + BayesianVARMAX( + k_endog=2, + order=(1, 0), + k_exog=3, + exog_state_names=["a", "b"], + verbose=False, + measurement_error=False, + stationary_initialization=False, ) - for name, exog_names in exog_state_names.items() - } - else: - exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)] - 
exog_data = { - "exogenous_data": pd.DataFrame( - rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX), - columns=exog_names, - index=time_idx, + + # Dict case + with pytest.raises( + ValueError, + match="If k_exog is an int, exog_state_names must be a list of the same length", + ): + BayesianVARMAX( + k_endog=2, + order=(1, 0), + k_exog=2, + exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, + verbose=False, + measurement_error=False, + stationary_initialization=False, ) - } - mod = BayesianVARMAX( - endog_names=endog_names, - order=(1, 0), - k_exog=k_exog, - exog_state_names=exog_state_names, - verbose=True, - measurement_error=False, - stationary_initialization=False, - mode="JAX", - ) + # dict + list + with pytest.raises( + ValueError, match="If k_exog is a dict, exog_state_names must be a dict as well" + ): + BayesianVARMAX( + endog_names=["observed_0", "observed_1"], + order=(1, 0), + k_exog={"observed_0": 1, "observed_1": 1}, + exog_state_names=["a", "b"], + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) - with pm.Model(coords=mod.coords) as m: - for var_name, data in exog_data.items(): - pm.Data(var_name, data, dims=mod.data_info[var_name]["dims"]) + # Dict/dict, key mismatch + with pytest.raises( + ValueError, match="Keys of k_exog and exog_state_names dicts must match" + ): + BayesianVARMAX( + endog_names=["observed_0", "observed_1"], + order=(1, 0), + k_exog={"observed_0": 1, "observed_1": 1}, + exog_state_names={"observed_0": ["a"], "observed_2": ["b"]}, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) - x0 = pm.Deterministic("x0", pt.zeros(mod.k_states), dims=mod.param_dims["x0"]) - P0_diag = pm.Exponential("P0_diag", 1.0, dims=mod.param_dims["P0"][0]) - P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=mod.param_dims["P0"]) + # Dict/dict, length mismatch + with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"): + 
BayesianVARMAX( + endog_names=["observed_0", "observed_1"], + order=(1, 0), + k_exog={"observed_0": 2, "observed_1": 1}, + exog_state_names={"observed_0": ["a"], "observed_1": ["b"]}, + verbose=False, + measurement_error=False, + stationary_initialization=False, + ) - ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=mod.param_dims["ar_params"]) - state_cov_diag = pm.Exponential("state_cov_diag", 1.0, dims=mod.param_dims["state_cov"][0]) - state_cov = pm.Deterministic( - "state_cov", pt.diag(state_cov_diag), dims=mod.param_dims["state_cov"] + @pytest.mark.parametrize( + "k_exog, exog_state_names", + [ + (2, None), + (None, ["foo", "bar"]), + (None, {"y1": ["a", "b"], "y2": ["c"]}), + ], + ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"], + ) + @pytest.mark.filterwarnings("ignore::UserWarning") + def test_varmax_with_exog(self, rng, k_exog, exog_state_names): + endog_names = ["y1", "y2", "y3"] + n_obs = 50 + time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D") + + y = rng.normal(size=(n_obs, len(endog_names))) + df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX) + + if isinstance(exog_state_names, dict): + exog_data = { + f"{name}_exogenous_data": pd.DataFrame( + rng.normal(size=(n_obs, len(exog_names))).astype(floatX), + columns=exog_names, + index=time_idx, + ) + for name, exog_names in exog_state_names.items() + } + else: + exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)] + exog_data = { + "exogenous_data": pd.DataFrame( + rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX), + columns=exog_names, + index=time_idx, + ) + } + + mod = BayesianVARMAX( + endog_names=endog_names, + order=(1, 0), + k_exog=k_exog, + exog_state_names=exog_state_names, + verbose=False, + measurement_error=False, + stationary_initialization=False, + mode="JAX", ) - # Exogenous priors - if isinstance(mod.exog_state_names, list): - beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, 
dims=mod.param_dims["beta_exog"]) - elif isinstance(mod.exog_state_names, dict): - for name in mod.exog_state_names: - if mod.exog_state_names.get(name): - pm.Normal(f"beta_{name}", mu=0, sigma=1, dims=mod.param_dims[f"beta_{name}"]) + with pm.Model(coords=mod.coords) as m: + for var_name, data in exog_data.items(): + pm.Data(var_name, data, dims=mod.data_info[var_name]["dims"]) - mod.build_statespace_graph(data=df) + x0 = pm.Deterministic("x0", pt.zeros(mod.k_states), dims=mod.param_dims["x0"]) + P0_diag = pm.Exponential("P0_diag", 1.0, dims=mod.param_dims["P0"][0]) + P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=mod.param_dims["P0"]) - with freeze_dims_and_data(m): - prior = pm.sample_prior_predictive( - draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"} - ) + ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=mod.param_dims["ar_params"]) + state_cov_diag = pm.Exponential( + "state_cov_diag", 1.0, dims=mod.param_dims["state_cov"][0] + ) + state_cov = pm.Deterministic( + "state_cov", pt.diag(state_cov_diag), dims=mod.param_dims["state_cov"] + ) - prior_cond = mod.sample_conditional_prior(prior, mvn_method="eigh") - beta_dot_data = prior_cond.filtered_prior_observed.values - prior_cond.filtered_prior.values + # Exogenous priors + if isinstance(mod.exog_state_names, list): + beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=mod.param_dims["beta_exog"]) + elif isinstance(mod.exog_state_names, dict): + for name in mod.exog_state_names: + if mod.exog_state_names.get(name): + pm.Normal( + f"beta_{name}", mu=0, sigma=1, dims=mod.param_dims[f"beta_{name}"] + ) + + mod.build_statespace_graph(data=df) + + with freeze_dims_and_data(m): + prior = pm.sample_prior_predictive( + draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"} + ) - if isinstance(exog_state_names, list) or k_exog is not None: - beta = prior.prior.beta_exog - assert beta.shape == (1, 10, 3, 2) + prior_cond = mod.sample_conditional_prior(prior, mvn_method="eigh") + beta_dot_data = 
prior_cond.filtered_prior_observed.values - prior_cond.filtered_prior.values - np.testing.assert_allclose( - beta_dot_data, - np.einsum("tx,...sx->...ts", exog_data["exogenous_data"].values, beta), - atol=1e-2, - ) + if isinstance(exog_state_names, list) or k_exog is not None: + beta = prior.prior.beta_exog + assert beta.shape == (1, 10, 3, 2) - elif isinstance(exog_state_names, dict): - assert prior.prior.beta_y1.shape == (1, 10, 2) - assert prior.prior.beta_y2.shape == (1, 10, 1) + np.testing.assert_allclose( + beta_dot_data, + np.einsum("tx,...sx->...ts", exog_data["exogenous_data"].values, beta), + atol=1e-2, + ) - obs_intercept = [ - np.einsum("tx,...x->...t", exog_data[f"{name}_exogenous_data"].values, beta) - for name, beta in zip(["y1", "y2"], [prior.prior.beta_y1, prior.prior.beta_y2]) - ] + elif isinstance(exog_state_names, dict): + assert prior.prior.beta_y1.shape == (1, 10, 2) + assert prior.prior.beta_y2.shape == (1, 10, 1) + + obs_intercept = [ + np.einsum("tx,...x->...t", exog_data[f"{name}_exogenous_data"].values, beta) + for name, beta in zip(["y1", "y2"], [prior.prior.beta_y1, prior.prior.beta_y2]) + ] - # y3 has no exogenous variables - obs_intercept.append(np.zeros_like(obs_intercept[0])) + # y3 has no exogenous variables + obs_intercept.append(np.zeros_like(obs_intercept[0])) - np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2) + np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)