Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc_extras/statespace/models/VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def __init__(
mode=mode,
)

self._needs_exog_data = exog_state_names is not None and len(exog_state_names) > 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is len(exog_state_names)>0 intended to handle when a user inputs an empty list? What about if a user inputs a dictionary with an empty list? I think if you have something like {'endog1': [], 'endog2': []} then you will get a True.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, I switched it to be a bit more robust to all possible cases.


# Save counts of the number of parameters in each category
self.param_counts = {
"x0": k_states * (1 - self.stationary_initialization),
Expand Down Expand Up @@ -337,7 +339,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:

@property
def data_info(self) -> dict[str, dict[str, Any]]:
info = None
info = {}

if isinstance(self.exog_state_names, list):
info = {
Expand Down
138 changes: 101 additions & 37 deletions tests/statespace/models/test_VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,14 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
def test_impulse_response(parameters, varma_mod, idata, rng):
irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters)

assert not np.any(np.isnan(irf.irf.values))
assert np.isfinite(irf.irf.values).all()


def test_forecast(varma_mod, idata, rng):
forecast = varma_mod.forecast(idata.prior, periods=10, random_seed=rng)

assert np.isfinite(forecast.forecast_latent.values).all()
assert np.isfinite(forecast.forecast_observed.values).all()


class TestVARMAXWithExogenous:
Expand Down Expand Up @@ -436,42 +443,8 @@ def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data):
stationary_initialization=False,
)

@pytest.mark.parametrize(
"k_exog, exog_state_names",
[
(2, None),
(None, ["foo", "bar"]),
(None, {"y1": ["a", "b"], "y2": ["c"]}),
],
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
endog_names = ["y1", "y2", "y3"]
n_obs = 50
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

y = rng.normal(size=(n_obs, len(endog_names)))
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

if isinstance(exog_state_names, dict):
exog_data = {
f"{name}_exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
for name, exog_names in exog_state_names.items()
}
else:
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
exog_data = {
"exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
}
def _build_varmax(self, df, k_exog, exog_state_names, exog_data):
endog_names = df.columns.values.tolist()

mod = BayesianVARMAX(
endog_names=endog_names,
Expand Down Expand Up @@ -512,6 +485,47 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):

mod.build_statespace_graph(data=df)

return mod, m

@pytest.mark.parametrize(
"k_exog, exog_state_names",
[
(2, None),
(None, ["foo", "bar"]),
(None, {"y1": ["a", "b"], "y2": ["c"]}),
],
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
endog_names = ["y1", "y2", "y3"]
n_obs = 50
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

y = rng.normal(size=(n_obs, len(endog_names)))
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

if isinstance(exog_state_names, dict):
exog_data = {
f"{name}_exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
for name, exog_names in exog_state_names.items()
}
else:
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
exog_data = {
"exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
columns=exog_names,
index=time_idx,
)
}

mod, m = self._build_varmax(df, k_exog, exog_state_names, exog_data)

with freeze_dims_and_data(m):
prior = pm.sample_prior_predictive(
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
Expand Down Expand Up @@ -543,3 +557,53 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
obs_intercept.append(np.zeros_like(obs_intercept[0]))

np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_forecast_with_exog(self, rng):
endog_names = ["y1", "y2", "y3"]
n_obs = 50
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

y = rng.normal(size=(n_obs, len(endog_names)))
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

mod, m = self._build_varmax(
df,
k_exog=2,
exog_state_names=None,
exog_data={
"exogenous_data": pd.DataFrame(
rng.normal(size=(n_obs, 2)).astype(floatX),
columns=["exogenous_0", "exogenous_1"],
index=time_idx,
)
},
)

with freeze_dims_and_data(m):
prior = pm.sample_prior_predictive(
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
)

with pytest.raises(
ValueError,
match="This model was fit using exogenous data. Forecasting cannot be performed "
"without providing scenario data",
):
mod.forecast(prior.prior, periods=10, random_seed=rng)

forecast = mod.forecast(
prior.prior,
periods=10,
random_seed=rng,
scenario={
"exogenous_data": pd.DataFrame(
rng.normal(size=(10, 2)).astype(floatX),
columns=["exogenous_0", "exogenous_1"],
index=pd.date_range(start=df.index[-1], periods=10, freq="D"),
)
},
)

assert np.isfinite(forecast.forecast_latent.values).all()
assert np.isfinite(forecast.forecast_observed.values).all()
Loading