Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions pymc_extras/statespace/models/VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,6 @@ def __init__(
The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
and "cholesky". See the docs for kalman filters for more details.

state_structure: str, default "fast"
How to represent the state-space system. When "interpretable", each element of the state vector will have a
precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger
state vector, which may hurt performance.

When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are
mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation
in this case.

measurement_error: bool, default True
If true, a measurement error term is added to the model.

Expand All @@ -181,8 +172,10 @@ def __init__(
if len(endog_names) != k_endog:
raise ValueError("Length of provided endog_names does not match provided k_endog")

needs_exog_data = False

if k_exog is not None and not isinstance(k_exog, int | dict):
raise ValueError("If not None, k_endog must be either an int or a dict")
raise ValueError("If not None, k_exog must be either an int or a dict")
if exog_state_names is not None and not isinstance(exog_state_names, list | dict):
raise ValueError("If not None, exog_state_names must be either a list or a dict")

Expand All @@ -208,6 +201,7 @@ def __init__(
"If both k_endog and exog_state_names are provided, lengths of exog_state_names "
"lists must match corresponding values in k_exog"
)
needs_exog_data = True

if k_exog is not None and exog_state_names is None:
if isinstance(k_exog, int):
Expand All @@ -216,12 +210,14 @@ def __init__(
exog_state_names = {
name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items()
}
needs_exog_data = True

if k_exog is None and exog_state_names is not None:
if isinstance(exog_state_names, list):
k_exog = len(exog_state_names)
elif isinstance(exog_state_names, dict):
k_exog = {name: len(names) for name, names in exog_state_names.items()}
needs_exog_data = True

# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
# then we can drop back to the list case.
Expand Down Expand Up @@ -254,6 +250,8 @@ def __init__(
mode=mode,
)

self._needs_exog_data = needs_exog_data

# Save counts of the number of parameters in each category
self.param_counts = {
"x0": k_states * (1 - self.stationary_initialization),
Expand Down Expand Up @@ -337,7 +335,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:

@property
def data_info(self) -> dict[str, dict[str, Any]]:
info = None
info = {}

if isinstance(self.exog_state_names, list):
info = {
Expand Down
140 changes: 103 additions & 37 deletions tests/statespace/models/test_VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,14 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
def test_impulse_response(parameters, varma_mod, idata, rng):
    """Check that impulse responses computed from prior draws are finite.

    ``parameters`` is forwarded verbatim as keyword arguments to
    ``impulse_response_function``.
    """
    irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters)

    # A single finiteness check rejects NaN as well as +/-inf, so a separate
    # NaN-only assertion would add nothing.
    assert np.isfinite(irf.irf.values).all()


def test_forecast(varma_mod, idata, rng):
    """Check that a 10-period forecast from prior draws is finite everywhere."""
    result = varma_mod.forecast(idata.prior, periods=10, random_seed=rng)

    # Latent states first, then the observed series.
    for var_name in ("forecast_latent", "forecast_observed"):
        assert np.all(np.isfinite(getattr(result, var_name).values))


class TestVARMAXWithExogenous:
Expand Down Expand Up @@ -436,42 +443,8 @@ def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data):
stationary_initialization=False,
)

@pytest.mark.parametrize(
"k_exog, exog_state_names",
[
(2, None),
(None, ["foo", "bar"]),
(None, {"y1": ["a", "b"], "y2": ["c"]}),
],
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
endog_names = ["y1", "y2", "y3"]
n_obs = 50
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

y = rng.normal(size=(n_obs, len(endog_names)))
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

if isinstance(exog_state_names, dict):
exog_data = {
f"{name}_exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
for name, exog_names in exog_state_names.items()
}
else:
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
exog_data = {
"exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
}
def _build_varmax(self, df, k_exog, exog_state_names, exog_data):
endog_names = df.columns.values.tolist()

mod = BayesianVARMAX(
endog_names=endog_names,
Expand Down Expand Up @@ -512,6 +485,47 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):

mod.build_statespace_graph(data=df)

return mod, m

@pytest.mark.parametrize(
"k_exog, exog_state_names",
[
(2, None),
(None, ["foo", "bar"]),
(None, {"y1": ["a", "b"], "y2": ["c"]}),
],
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
endog_names = ["y1", "y2", "y3"]
n_obs = 50
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

y = rng.normal(size=(n_obs, len(endog_names)))
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

if isinstance(exog_state_names, dict):
exog_data = {
f"{name}_exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
for name, exog_names in exog_state_names.items()
}
else:
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
exog_data = {
"exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
}

mod, m = self._build_varmax(df, k_exog, exog_state_names, exog_data)

with freeze_dims_and_data(m):
prior = pm.sample_prior_predictive(
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
Expand Down Expand Up @@ -543,3 +557,55 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
obs_intercept.append(np.zeros_like(obs_intercept[0]))

np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_forecast_with_exog(self, rng):
    """Forecasting a model built with exogenous data needs scenario data.

    Builds a model on three random series with ``k_exog=2``, checks that
    ``forecast`` raises when no scenario is given, and that supplying one
    yields finite latent and observed forecasts.
    """
    n_obs = 50
    n_exog = 2
    horizon = 10
    endog_names = ["y1", "y2", "y3"]
    exog_names = ["exogenous_0", "exogenous_1"]
    time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

    endog_draws = rng.normal(size=(n_obs, len(endog_names)))
    df = pd.DataFrame(endog_draws, columns=endog_names, index=time_idx).astype(floatX)

    exog_frame = pd.DataFrame(
        rng.normal(size=(n_obs, n_exog)).astype(floatX),
        columns=exog_names,
        index=time_idx,
    )
    mod, m = self._build_varmax(
        df,
        k_exog=n_exog,
        exog_state_names=None,
        exog_data={"exogenous_data": exog_frame},
    )

    assert mod._needs_exog_data

    with freeze_dims_and_data(m):
        prior = pm.sample_prior_predictive(
            draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
        )

    # Without a scenario the model must refuse to forecast.
    expected_error = (
        "This model was fit using exogenous data. Forecasting cannot be performed "
        "without providing scenario data"
    )
    with pytest.raises(ValueError, match=expected_error):
        mod.forecast(prior.prior, periods=horizon, random_seed=rng)

    # The scenario index starts at the last observed date.
    scenario_index = pd.date_range(start=df.index[-1], periods=horizon, freq="D")
    scenario_frame = pd.DataFrame(
        rng.normal(size=(horizon, n_exog)).astype(floatX),
        columns=exog_names,
        index=scenario_index,
    )
    forecast = mod.forecast(
        prior.prior,
        periods=horizon,
        random_seed=rng,
        scenario={"exogenous_data": scenario_frame},
    )

    for var_name in ("forecast_latent", "forecast_observed"):
        assert np.isfinite(getattr(forecast, var_name).values).all()
Loading