Skip to content

Commit a27b47b

Browse files
Simplify how exog data is stored
1 parent cccfcad commit a27b47b

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ def __init__(
233233
self._fit_coords: dict[str, Sequence[str]] | None = None
234234
self._fit_dims: dict[str, Sequence[str]] | None = None
235235
self._fit_data: pt.TensorVariable | None = None
236+
self._fit_exog_data: dict[str, dict] = {}
236237

237238
self._needs_exog_data = None
238-
self._exog_names = []
239-
self._exog_data_info = {}
240239
self._name_to_variable = {}
241240
self._name_to_data = {}
242241

@@ -671,7 +670,7 @@ def _save_exogenous_data_info(self):
671670
pymc_mod = modelcontext(None)
672671
for data_name in self.data_names:
673672
data = pymc_mod[data_name]
674-
self._exog_data_info[data_name] = {
673+
self._fit_exog_data[data_name] = {
675674
"name": data_name,
676675
"value": data.get_value(),
677676
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
@@ -685,7 +684,7 @@ def _insert_random_variables(self):
685684
--------
686685
.. code:: python
687686
688-
ss_mod = pmss.BayesianSARIMA(order=(2, 0, 2), verbose=False, stationary_initialization=True)
687+
ss_mod = pmss.BayesianSARIMAX(order=(2, 0, 2), verbose=False, stationary_initialization=True)
689688
with pm.Model():
690689
x0 = pm.Normal('x0', size=ss_mod.k_states)
691690
ar_params = pm.Normal('ar_params', size=ss_mod.p)
@@ -1082,7 +1081,7 @@ def _kalman_filter_outputs_from_dummy_graph(
10821081

10831082
for name in self.data_names:
10841083
if name not in pm_mod:
1085-
pm.Data(**self._exog_data_info[name])
1084+
pm.Data(**self._fit_exog_data[name])
10861085

10871086
self._insert_data_variables()
10881087

@@ -1229,7 +1228,7 @@ def _sample_conditional(
12291228
method=mvn_method,
12301229
)
12311230

1232-
obs_mu = (Z @ mu[..., None]).squeeze(-1)
1231+
obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
12331232
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H
12341233

12351234
SequenceMvNormal(
@@ -1351,7 +1350,7 @@ def _sample_unconditional(
13511350
self._insert_random_variables()
13521351

13531352
for name in self.data_names:
1354-
pm.Data(**self._exog_data_info[name])
1353+
pm.Data(**self._fit_exog_data[name])
13551354

13561355
self._insert_data_variables()
13571356

@@ -1651,7 +1650,7 @@ def sample_statespace_matrices(
16511650
self._insert_random_variables()
16521651

16531652
for name in self.data_names:
1654-
pm.Data(**self._exog_data_info[name])
1653+
pm.Data(**self.data_info[name])
16551654

16561655
self._insert_data_variables()
16571656
matrices = self.unpack_statespace()
@@ -1703,7 +1702,7 @@ def sample_filter_outputs(
17031702

17041703
if self.data_names:
17051704
for name in self.data_names:
1706-
pm.Data(**self._exog_data_info[name])
1705+
pm.Data(**self._fit_exog_data[name])
17071706

17081707
self._insert_data_variables()
17091708

@@ -1846,7 +1845,7 @@ def _validate_scenario_data(
18461845
}
18471846

18481847
if self._needs_exog_data and scenario is None:
1849-
exog_str = ",".join(self._exog_names)
1848+
exog_str = ",".join(self.data_names)
18501849
suffix = "s" if len(exog_str) > 1 else ""
18511850
raise ValueError(
18521851
f"This model was fit using exogenous data. Forecasting cannot be performed without "
@@ -1855,7 +1854,7 @@ def _validate_scenario_data(
18551854

18561855
if isinstance(scenario, dict):
18571856
for name, data in scenario.items():
1858-
if name not in self._exog_names:
1857+
if name not in self.data_names:
18591858
raise ValueError(
18601859
f"Scenario data provided for variable '{name}', which is not an exogenous variable "
18611860
f"used to fit the model."
@@ -1896,12 +1895,12 @@ def _validate_scenario_data(
18961895
# name should only be None on the first non-recursive call. We only arrive to this branch in that case
18971896
# if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
18981897
# needs to be set.
1899-
if len(self._exog_names) > 1:
1898+
if len(self.data_names) > 1:
19001899
raise ValueError(
19011900
"Multiple exogenous variables were used to fit the model. Provide a dictionary of "
19021901
"scenario data instead."
19031902
)
1904-
name = self._exog_names[0]
1903+
name = self.data_names[0]
19051904

19061905
# Omit dataframe from this basic shape check so we can give more detailed information about missing columns
19071906
# in the next check
@@ -2103,7 +2102,7 @@ def _finalize_scenario_initialization(
21032102
return scenario
21042103

21052104
# This was already checked as valid
2106-
name = self._exog_names[0] if name is None else name
2105+
name = self.data_names[0] if name is None else name
21072106

21082107
# Small tidying up in the case we just have a single scenario that's already a dataframe.
21092108
if isinstance(scenario, pd.DataFrame | pd.Series):

tests/statespace/core/test_statespace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Sequence
44
from functools import partial
5+
from typing import Any
56

67
import numpy as np
78
import pandas as pd
@@ -44,9 +45,13 @@ def make_symbolic_graph(self):
4445
pass
4546

4647
@property
47-
def data_info(self):
48+
def data_info(self) -> dict[str, dict[str, Any]]:
4849
return data_info
4950

51+
@property
52+
def data_names(self) -> list[str]:
53+
return list(data_info.keys()) if data_info is not None else []
54+
5055
ss = StateSpace(
5156
k_states=k_states,
5257
k_endog=k_endog,
@@ -55,7 +60,6 @@ def data_info(self):
5560
verbose=verbose,
5661
)
5762
ss._needs_exog_data = data_info is not None
58-
ss._exog_names = list(data_info.keys()) if data_info is not None else []
5963

6064
return ss
6165

0 commit comments

Comments
 (0)