Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
453 changes: 294 additions & 159 deletions notebooks/SARMA Example.ipynb → notebooks/SARIMAX Example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pymc_extras/statespace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from pymc_extras.statespace.core.compile import compile_statespace
from pymc_extras.statespace.models import structural
from pymc_extras.statespace.models.ETS import BayesianETS
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMAX
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX

__all__ = [
"BayesianETS",
"BayesianSARIMA",
"BayesianSARIMAX",
"BayesianVARMAX",
"compile_statespace",
"structural",
Expand Down
27 changes: 13 additions & 14 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ def __init__(
self._fit_coords: dict[str, Sequence[str]] | None = None
self._fit_dims: dict[str, Sequence[str]] | None = None
self._fit_data: pt.TensorVariable | None = None
self._fit_exog_data: dict[str, dict] = {}

self._needs_exog_data = None
self._exog_names = []
self._exog_data_info = {}
self._name_to_variable = {}
self._name_to_data = {}

Expand Down Expand Up @@ -671,7 +670,7 @@ def _save_exogenous_data_info(self):
pymc_mod = modelcontext(None)
for data_name in self.data_names:
data = pymc_mod[data_name]
self._exog_data_info[data_name] = {
self._fit_exog_data[data_name] = {
"name": data_name,
"value": data.get_value(),
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
Expand All @@ -685,7 +684,7 @@ def _insert_random_variables(self):
--------
.. code:: python

ss_mod = pmss.BayesianSARIMA(order=(2, 0, 2), verbose=False, stationary_initialization=True)
ss_mod = pmss.BayesianSARIMAX(order=(2, 0, 2), verbose=False, stationary_initialization=True)
with pm.Model():
x0 = pm.Normal('x0', size=ss_mod.k_states)
ar_params = pm.Normal('ar_params', size=ss_mod.p)
Expand Down Expand Up @@ -1082,7 +1081,7 @@ def _kalman_filter_outputs_from_dummy_graph(

for name in self.data_names:
if name not in pm_mod:
pm.Data(**self._exog_data_info[name])
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

Expand Down Expand Up @@ -1229,7 +1228,7 @@ def _sample_conditional(
method=mvn_method,
)

obs_mu = (Z @ mu[..., None]).squeeze(-1)
obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H

SequenceMvNormal(
Expand Down Expand Up @@ -1351,7 +1350,7 @@ def _sample_unconditional(
self._insert_random_variables()

for name in self.data_names:
pm.Data(**self._exog_data_info[name])
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

Expand Down Expand Up @@ -1651,7 +1650,7 @@ def sample_statespace_matrices(
self._insert_random_variables()

for name in self.data_names:
pm.Data(**self._exog_data_info[name])
pm.Data(**self.data_info[name])

self._insert_data_variables()
matrices = self.unpack_statespace()
Expand Down Expand Up @@ -1703,7 +1702,7 @@ def sample_filter_outputs(

if self.data_names:
for name in self.data_names:
pm.Data(**self._exog_data_info[name])
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

Expand Down Expand Up @@ -1846,7 +1845,7 @@ def _validate_scenario_data(
}

if self._needs_exog_data and scenario is None:
exog_str = ",".join(self._exog_names)
exog_str = ",".join(self.data_names)
suffix = "s" if len(exog_str) > 1 else ""
raise ValueError(
f"This model was fit using exogenous data. Forecasting cannot be performed without "
Expand All @@ -1855,7 +1854,7 @@ def _validate_scenario_data(

if isinstance(scenario, dict):
for name, data in scenario.items():
if name not in self._exog_names:
if name not in self.data_names:
raise ValueError(
f"Scenario data provided for variable '{name}', which is not an exogenous variable "
f"used to fit the model."
Expand Down Expand Up @@ -1896,12 +1895,12 @@ def _validate_scenario_data(
# name should only be None on the first non-recursive call. We only arrive to this branch in that case
# if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
# needs to be set.
if len(self._exog_names) > 1:
if len(self.data_names) > 1:
raise ValueError(
"Multiple exogenous variables were used to fit the model. Provide a dictionary of "
"scenario data instead."
)
name = self._exog_names[0]
name = self.data_names[0]

# Omit dataframe from this basic shape check so we can give more detailed information about missing columns
# in the next check
Expand Down Expand Up @@ -2103,7 +2102,7 @@ def _finalize_scenario_initialization(
return scenario

# This was already checked as valid
name = self._exog_names[0] if name is None else name
name = self.data_names[0] if name is None else name

# Small tidying up in the case we just have a single scenario that's already a dataframe.
if isinstance(scenario, pd.DataFrame | pd.Series):
Expand Down
Loading