Skip to content

Commit 89c6bc0

Browse files
Allow exogenous regressions in BayesianSARIMAX (#566)
* Re-arrange docstrings
* Simplify how exog data is stored
* Allow exogenous variables in BayesianSARIMAX, rename SARIMA -> SARIMAX
* Use constant for exogenous dim
1 parent 32a42b4 commit 89c6bc0

File tree

10 files changed: 520 additions and 265 deletions.

notebooks/SARMA Example.ipynb renamed to notebooks/SARIMAX Example.ipynb

Lines changed: 294 additions & 159 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from pymc_extras.statespace.core.compile import compile_statespace
22
from pymc_extras.statespace.models import structural
33
from pymc_extras.statespace.models.ETS import BayesianETS
4-
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
4+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
55
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
66

77
__all__ = [
88
"BayesianETS",
9-
"BayesianSARIMA",
9+
"BayesianSARIMAX",
1010
"BayesianVARMAX",
1111
"compile_statespace",
1212
"structural",

pymc_extras/statespace/core/statespace.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ def __init__(
233233
self._fit_coords: dict[str, Sequence[str]] | None = None
234234
self._fit_dims: dict[str, Sequence[str]] | None = None
235235
self._fit_data: pt.TensorVariable | None = None
236+
self._fit_exog_data: dict[str, dict] = {}
236237

237238
self._needs_exog_data = None
238-
self._exog_names = []
239-
self._exog_data_info = {}
240239
self._name_to_variable = {}
241240
self._name_to_data = {}
242241

@@ -671,7 +670,7 @@ def _save_exogenous_data_info(self):
671670
pymc_mod = modelcontext(None)
672671
for data_name in self.data_names:
673672
data = pymc_mod[data_name]
674-
self._exog_data_info[data_name] = {
673+
self._fit_exog_data[data_name] = {
675674
"name": data_name,
676675
"value": data.get_value(),
677676
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
@@ -685,7 +684,7 @@ def _insert_random_variables(self):
685684
--------
686685
.. code:: python
687686
688-
ss_mod = pmss.BayesianSARIMA(order=(2, 0, 2), verbose=False, stationary_initialization=True)
687+
ss_mod = pmss.BayesianSARIMAX(order=(2, 0, 2), verbose=False, stationary_initialization=True)
689688
with pm.Model():
690689
x0 = pm.Normal('x0', size=ss_mod.k_states)
691690
ar_params = pm.Normal('ar_params', size=ss_mod.p)
@@ -1082,7 +1081,7 @@ def _kalman_filter_outputs_from_dummy_graph(
10821081

10831082
for name in self.data_names:
10841083
if name not in pm_mod:
1085-
pm.Data(**self._exog_data_info[name])
1084+
pm.Data(**self._fit_exog_data[name])
10861085

10871086
self._insert_data_variables()
10881087

@@ -1229,7 +1228,7 @@ def _sample_conditional(
12291228
method=mvn_method,
12301229
)
12311230

1232-
obs_mu = (Z @ mu[..., None]).squeeze(-1)
1231+
obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
12331232
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H
12341233

12351234
SequenceMvNormal(
@@ -1351,7 +1350,7 @@ def _sample_unconditional(
13511350
self._insert_random_variables()
13521351

13531352
for name in self.data_names:
1354-
pm.Data(**self._exog_data_info[name])
1353+
pm.Data(**self._fit_exog_data[name])
13551354

13561355
self._insert_data_variables()
13571356

@@ -1651,7 +1650,7 @@ def sample_statespace_matrices(
16511650
self._insert_random_variables()
16521651

16531652
for name in self.data_names:
1654-
pm.Data(**self._exog_data_info[name])
1653+
pm.Data(**self.data_info[name])
16551654

16561655
self._insert_data_variables()
16571656
matrices = self.unpack_statespace()
@@ -1703,7 +1702,7 @@ def sample_filter_outputs(
17031702

17041703
if self.data_names:
17051704
for name in self.data_names:
1706-
pm.Data(**self._exog_data_info[name])
1705+
pm.Data(**self._fit_exog_data[name])
17071706

17081707
self._insert_data_variables()
17091708

@@ -1846,7 +1845,7 @@ def _validate_scenario_data(
18461845
}
18471846

18481847
if self._needs_exog_data and scenario is None:
1849-
exog_str = ",".join(self._exog_names)
1848+
exog_str = ",".join(self.data_names)
18501849
suffix = "s" if len(exog_str) > 1 else ""
18511850
raise ValueError(
18521851
f"This model was fit using exogenous data. Forecasting cannot be performed without "
@@ -1855,7 +1854,7 @@ def _validate_scenario_data(
18551854

18561855
if isinstance(scenario, dict):
18571856
for name, data in scenario.items():
1858-
if name not in self._exog_names:
1857+
if name not in self.data_names:
18591858
raise ValueError(
18601859
f"Scenario data provided for variable '{name}', which is not an exogenous variable "
18611860
f"used to fit the model."
@@ -1896,12 +1895,12 @@ def _validate_scenario_data(
18961895
# name should only be None on the first non-recursive call. We only arrive to this branch in that case
18971896
# if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
18981897
# needs to be set.
1899-
if len(self._exog_names) > 1:
1898+
if len(self.data_names) > 1:
19001899
raise ValueError(
19011900
"Multiple exogenous variables were used to fit the model. Provide a dictionary of "
19021901
"scenario data instead."
19031902
)
1904-
name = self._exog_names[0]
1903+
name = self.data_names[0]
19051904

19061905
# Omit dataframe from this basic shape check so we can give more detailed information about missing columns
19071906
# in the next check
@@ -2103,7 +2102,7 @@ def _finalize_scenario_initialization(
21032102
return scenario
21042103

21052104
# This was already checked as valid
2106-
name = self._exog_names[0] if name is None else name
2105+
name = self.data_names[0] if name is None else name
21072106

21082107
# Small tidying up in the case we just have a single scenario that's already a dataframe.
21092108
if isinstance(scenario, pd.DataFrame | pd.Series):

0 commit comments

Comments
 (0)