Skip to content

Commit 6a6304b

Browse files
Allow exogenous variables in BayesianSARIMAX, rename SARIMA -> SARIMAX
1 parent 1abe498 commit 6a6304b

File tree

7 files changed

+402
-194
lines changed

7 files changed

+402
-194
lines changed

notebooks/SARMA Example.ipynb renamed to notebooks/SARIMAX Example.ipynb

Lines changed: 294 additions & 159 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from pymc_extras.statespace.core.compile import compile_statespace
22
from pymc_extras.statespace.models import structural
33
from pymc_extras.statespace.models.ETS import BayesianETS
4-
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
4+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
55
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
66

77
__all__ = [
88
"BayesianETS",
9-
"BayesianSARIMA",
9+
"BayesianSARIMAX",
1010
"BayesianVARMAX",
1111
"compile_statespace",
1212
"structural",

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
SARIMAX_STATE_STRUCTURES,
2323
SEASONAL_AR_PARAM_DIM,
2424
SEASONAL_MA_PARAM_DIM,
25+
TIME_DIM,
2526
)
2627

2728

@@ -38,13 +39,12 @@ def _verify_order(p, d, q, P, D, Q, S):
3839
)
3940

4041

41-
class BayesianSARIMA(PyMCStateSpace):
42+
class BayesianSARIMAX(PyMCStateSpace):
4243
r"""
4344
Seasonal AutoRegressive Integrated Moving Average with eXogenous regressors.
4445
45-
This class implements a Bayesian approach to SARIMA models, which are used for
46-
modeling univariate time series data with seasonal and non-seasonal components.
47-
The model supports exogenous regressors and can be represented in state-space form.
46+
This class implements a Bayesian approach to SARIMAX models, which are used for modeling univariate time series data
47+
with seasonal and non-seasonal components. The model supports exogenous regressors.
4848
4949
Notes
5050
-----
@@ -98,14 +98,14 @@ class BayesianSARIMA(PyMCStateSpace):
9898
9999
Examples
100100
--------
101-
The following example shows how to build an ARMA(1, 1) model -- ARIMA(1, 0, 1) -- using the BayesianSARIMA class:
101+
The following example shows how to build an ARMA(1, 1) model -- ARIMA(1, 0, 1) -- using the BayesianSARIMAX class:
102102
103103
.. code:: python
104104
105105
import pymc_extras.statespace as pmss
106106
import pymc as pm
107107
108-
ss_mod = pmss.BayesianSARIMA(order=(1, 0, 1), verbose=True)
108+
ss_mod = pmss.BayesianSARIMAX(order=(1, 0, 1), verbose=True)
109109
110110
with pm.Model(coords=ss_mod.coords) as arma_model:
111111
state_sigmas = pm.HalfNormal("sigma_state", sigma=1.0, dims=ss_mod.param_dims["sigma_state"])
@@ -140,11 +140,11 @@ def __init__(
140140
mode: str | Mode | None = None,
141141
):
142142
"""
143-
Initialize a BayesianSARIMA model.
143+
Initialize a BayesianSARIMAX model.
144144
145145
Parameters
146146
----------
147-
order : tuple(int, int, int)
147+
order : tuple of int, int, int
148148
Order of the ARIMA process. The order has the notation (p, d, q), where p is the number of autoregressive
149149
lags, q is the number of moving average components, and d is order of integration -- the number of
150150
differences needed to render the data stationary.
@@ -153,7 +153,7 @@ def __init__(
153153
This is only possible if state_structure = 'fast'. For interpretable states, the user must manually
154154
difference the data prior to calling the `build_statespace_graph` method.
155155
156-
seasonal_order : tuple(int, int, int, int), optional
156+
seasonal_order : tuple of int, int, int, int, optional
157157
Seasonal order of the SARIMA process. The order has the notation (P, D, Q, S), where P is the number of seasonal
158158
lags to include, Q is the number of seasonal innovation lags to include, and D is the number of seasonal
159159
differences to perform. S is the length of the season.
@@ -232,6 +232,12 @@ def __init__(
232232

233233
self.stationary_initialization = stationary_initialization
234234

235+
if (self.d or self.D) and self.stationary_initialization:
236+
raise ValueError(
237+
"Cannot use stationary initialization with differencing. "
238+
"Set stationary_initialization=False."
239+
)
240+
235241
self.state_structure = state_structure
236242

237243
self._p_max = max(1, self.p + self.P * self.S)
@@ -271,6 +277,7 @@ def __init__(
271277
measurement_error=measurement_error,
272278
mode=mode,
273279
)
280+
self._needs_exog_data = self.k_exog > 0
274281

275282
@property
276283
def param_names(self):
@@ -303,6 +310,17 @@ def param_names(self):
303310

304311
return names
305312

313+
@property
314+
def data_info(self) -> dict[str, dict[str, Any]]:
315+
info = {
316+
"exogenous_data": {
317+
"dims": (TIME_DIM, "exogenous"),
318+
"shape": (None, self.k_exog),
319+
}
320+
}
321+
322+
return {name: info[name] for name in self.data_names}
323+
306324
@property
307325
def param_info(self) -> dict[str, dict[str, Any]]:
308326
info = {
@@ -332,7 +350,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
332350
},
333351
"seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
334352
"seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
335-
"beta_exog": {"shape": (self.k_exog,), "constraints": "None"},
353+
"beta_exog": {"shape": (self.k_exog,), "constraints": "None", "dims": ("exogenous",)},
336354
}
337355

338356
for name in self.param_names:
@@ -357,11 +375,14 @@ def state_names(self):
357375
else:
358376
raise NotImplementedError()
359377

360-
if self.k_exog > 0:
361-
states += ["exogenous"]
362-
363378
return states
364379

380+
@property
381+
def data_names(self) -> list[str]:
382+
if self.k_exog > 0:
383+
return ["exogenous_data"]
384+
return []
385+
365386
@property
366387
def observed_states(self):
367388
return [self.state_names[0]]
@@ -396,7 +417,7 @@ def param_dims(self):
396417
del coord_map["seasonal_ar_params"]
397418
if self.Q == 0:
398419
del coord_map["seasonal_ma_params"]
399-
if not self.k_exog == 0:
420+
if self.k_exog == 0:
400421
del coord_map["beta_exog"]
401422
if self.stationary_initialization:
402423
del coord_map["P0"]
@@ -426,7 +447,7 @@ def _stationary_initialization(self):
426447
Q = self.ssm["state_cov"]
427448
c = self.ssm["state_intercept"]
428449

429-
x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)
450+
x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=False)
430451
P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method="bilinear")
431452

432453
return x0, P0
@@ -575,7 +596,7 @@ def make_symbolic_graph(self) -> None:
575596
"beta_exog", shape=(self.k_exog,), dtype=floatX
576597
)
577598

578-
self.ssm["obs_intercept", :, :] = exog_data @ exog_beta
599+
self.ssm["obs_intercept"] = (exog_data @ exog_beta)[:, None]
579600

580601
# Set up the state covariance matrix
581602
state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pymc_extras.statespace.models import structural
22
from pymc_extras.statespace.models.ETS import BayesianETS
3-
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
3+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
44
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
55

6-
__all__ = ["BayesianETS", "BayesianSARIMA", "BayesianVARMAX", "structural"]
6+
__all__ = ["BayesianSARIMAX", "BayesianVARMAX", "BayesianETS", "structural"]

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class AutoregressiveComponent(Component):
4646
The coefficient :math:`\rho_3` has been constrained to zero.
4747
4848
.. warning:: This class is meant to be used as a component in a structural time series model. For modeling of
49-
stationary processes with ARIMA, use ``statespace.BayesianSARIMA``.
49+
stationary processes with ARIMA, use ``statespace.BayesianSARIMAX``.
5050
5151
Examples
5252
--------

tests/statespace/models/structural/components/test_regression.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from numpy.testing import assert_allclose
77
from pytensor import config
88
from pytensor import tensor as pt
9-
from pytensor.graph.basic import explicit_graph_inputs
10-
from scipy.linalg import block_diag
119

1210
from pymc_extras.statespace.models import structural as st
1311
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
@@ -327,13 +325,14 @@ def test_regression_mixed_shared_and_not_shared():
327325
Z,
328326
np.concat(
329327
(
330-
block_diag(*[data_individual[:, np.newaxis] for _ in range(mod.k_endog)]),
331-
np.concat((data_joint[:, np.newaxis], data_joint[:, np.newaxis]), axis=1),
328+
pt.linalg.block_diag(
329+
*[data_individual[:, None] for _ in range(mod.k_endog)]
330+
).eval(),
331+
np.concat((data_joint[:, None], data_joint[:, None]), axis=1),
332332
),
333333
axis=2,
334334
),
335335
)
336336

337337
np.testing.assert_allclose(T, np.eye(mod.k_states))
338-
339338
np.testing.assert_allclose(R, np.eye(mod.k_states))

tests/statespace/models/test_SARIMAX.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from itertools import combinations
22

33
import numpy as np
4+
import pandas as pd
45
import pymc as pm
56
import pytensor
67
import pytensor.tensor as pt
78
import pytest
89
import statsmodels.api as sm
910

1011
from numpy.testing import assert_allclose, assert_array_less
12+
from pymc.testing import mock_sample_setup_and_teardown
13+
from pytensor.graph.basic import explicit_graph_inputs
1114

12-
from pymc_extras.statespace import BayesianSARIMA
15+
from pymc_extras.statespace import BayesianSARIMAX
1316
from pymc_extras.statespace.models.utilities import (
1417
make_harvey_state_names,
1518
make_SARIMA_transition_matrix,
@@ -27,6 +30,8 @@
2730
simulate_from_numpy_model,
2831
)
2932

33+
mock_sample = pytest.fixture()(mock_sample_setup_and_teardown)
34+
3035
floatX = pytensor.config.floatX
3136
ATOL = 1e-8 if floatX.endswith("64") else 1e-6
3237
RTOL = 0 if floatX.endswith("64") else 1e-6
@@ -167,7 +172,7 @@ def data():
167172

168173
@pytest.fixture(scope="session")
169174
def arima_mod():
170-
return BayesianSARIMA(order=(2, 0, 1), stationary_initialization=True, verbose=False)
175+
return BayesianSARIMAX(order=(2, 0, 1), stationary_initialization=True, verbose=False)
171176

172177

173178
@pytest.fixture(scope="session")
@@ -188,7 +193,7 @@ def pymc_mod(arima_mod):
188193

189194
@pytest.fixture(scope="session")
190195
def arima_mod_interp():
191-
return BayesianSARIMA(
196+
return BayesianSARIMAX(
192197
order=(3, 0, 3),
193198
stationary_initialization=False,
194199
verbose=False,
@@ -219,7 +224,7 @@ def pymc_mod_interp(arima_mod_interp):
219224

220225
def test_mode_argument():
221226
# Mode argument should be passed to the parent class
222-
mod = BayesianSARIMA(order=(0, 0, 3), mode="FAST_RUN", verbose=False)
227+
mod = BayesianSARIMAX(order=(0, 0, 3), mode="FAST_RUN", verbose=False)
223228
assert mod.mode == "FAST_RUN"
224229

225230

@@ -269,7 +274,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
269274
param_d = {name: getattr(np, floatX)(rng.normal(scale=0.1) ** 2) for name in param_names}
270275

271276
res = sm_sarimax.fit_constrained(param_d)
272-
mod = BayesianSARIMA(
277+
mod = BayesianSARIMAX(
273278
order=(p, d, q), seasonal_order=(P, D, Q, S), verbose=False, stationary_initialization=False
274279
)
275280

@@ -331,9 +336,9 @@ def test_interpretable_raises_if_d_nonzero():
331336
with pytest.raises(
332337
ValueError, match="Cannot use interpretable state structure with statespace differencing"
333338
):
334-
BayesianSARIMA(
339+
BayesianSARIMAX(
335340
order=(2, 1, 1),
336-
stationary_initialization=True,
341+
stationary_initialization=False,
337342
verbose=False,
338343
state_structure="interpretable",
339344
)
@@ -383,7 +388,7 @@ def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
383388

384389
for representation in SARIMAX_STATE_STRUCTURES:
385390
rng = np.random.default_rng(sum(map(ord, "representation test")))
386-
mod = BayesianSARIMA(
391+
mod = BayesianSARIMAX(
387392
order=(p, d, q),
388393
seasonal_order=(P, D, Q, S),
389394
stationary_initialization=False,
@@ -414,4 +419,52 @@ def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
414419
def test_invalid_order_raises(order, name):
415420
p, P, q, Q = order
416421
with pytest.raises(ValueError, match=f"The following {name} and seasonal {name} terms overlap"):
417-
BayesianSARIMA(order=(p, 0, q), seasonal_order=(P, 0, Q, 4))
422+
BayesianSARIMAX(order=(p, 0, q), seasonal_order=(P, 0, Q, 4))
423+
424+
425+
def test_SARIMA_with_exogenous(rng, mock_sample):
426+
ss_mod = BayesianSARIMAX(order=(3, 0, 1), seasonal_order=(1, 0, 0, 12), k_exog=2)
427+
428+
assert ss_mod.param_dims["beta_exog"] == ("exogenous",)
429+
assert ss_mod.data_names == ["exogenous_data"]
430+
assert ss_mod.coords["exogenous"] == ["exogenous_0", "exogenous_1"]
431+
432+
obs_intercept = ss_mod.ssm["obs_intercept"]
433+
assert obs_intercept.type.shape == (None, ss_mod.k_endog)
434+
435+
intercept_fn = pytensor.function(
436+
inputs=list(explicit_graph_inputs(obs_intercept)), outputs=obs_intercept
437+
)
438+
data_val = rng.normal(size=(100, 2)).astype(floatX)
439+
beta_val = rng.normal(size=(2,)).astype(floatX)
440+
441+
intercept_val = intercept_fn(data_val, beta_val)
442+
np.testing.assert_allclose(intercept_val, intercept_fn(data_val, beta_val))
443+
444+
data_df = pd.DataFrame(
445+
rng.normal(size=(100, 1)),
446+
index=pd.date_range(start="2020-01-01", periods=100, freq="D"),
447+
columns=["endog"],
448+
)
449+
450+
with pm.Model(coords=ss_mod.coords) as pymc_mod:
451+
pm.Data("exogenous_data", data_val, dims=["time", "exogenous"])
452+
453+
ar_params = pm.Normal("ar_params", dims=["lag_ar"])
454+
ma_params = pm.Normal("ma_params", dims=["lag_ma"])
455+
seasonal_ar_params = pm.Normal("seasonal_ar_params", dims=["seasonal_lag_ar"])
456+
457+
beta_exog = pm.Normal("beta_exog", dims=["exogenous"])
458+
459+
sigma_state = pm.Exponential("sigma_state", 1.0)
460+
461+
ss_mod.build_statespace_graph(data=data_df, save_kalman_filter_outputs_in_idata=True)
462+
idata = pm.sample(chains=2, draws=100)
463+
464+
assert "exogenous_data" in idata.constant_data
465+
assert idata.posterior.beta_exog.shape == (
466+
2,
467+
100,
468+
2,
469+
)
470+
np.testing.assert_allclose(ss_mod._fit_exog_data["exogenous_data"]["value"], data_val)

0 commit comments

Comments
 (0)