Skip to content

Commit d0b6fd6

Browse files
committed
added multivariate forecast with exogenous variables test
1 parent 669bd10 commit d0b6fd6

File tree

2 files changed

+111
-8
lines changed

2 files changed

+111
-8
lines changed

notebooks/multivariate_ssm.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,14 +2064,6 @@
20642064
" nuts_sampler=\"nutpie\", nuts_sampler_kwargs={\"backend\": \"JAX\", \"gradient_backend\": \"JAX\"}\n",
20652065
" )"
20662066
]
2067-
},
2068-
{
2069-
"cell_type": "code",
2070-
"execution_count": null,
2071-
"id": "86dd9e4e",
2072-
"metadata": {},
2073-
"outputs": [],
2074-
"source": []
20752067
}
20762068
],
20772069
"metadata": {

tests/statespace/core/test_statespace.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,23 @@ def exog_data(rng):
154154
return df.set_index("date")
155155

156156

157+
@pytest.fixture(scope="session")
def exog_data_mv(rng):
    """Simulate a small daily dataset with one regressor and two targets.

    The frame covers 2023-05-01 through 2023-05-10 and is indexed by
    ``date``. Column ``x1`` holds values drawn from {0, 1} cast to float,
    while ``y1`` and ``y2`` hold standard-normal draws. A few observations
    in each target are blanked out with NaN (rows 1, 3, 9 for ``y1`` and
    rows 3, 5, 7 for ``y2``).
    """
    dates = pd.date_range(start="2023-05-01", end="2023-05-10", freq="D")

    # Draw in the order x1 -> y1 -> y2 so the generator stream is consumed
    # in a fixed, reproducible sequence.
    x1 = rng.choice(2, size=10, replace=True).astype(float)
    y1 = rng.normal(size=(10,))
    y2 = rng.normal(size=(10,))

    frame = pd.DataFrame({"date": dates, "x1": x1, "y1": y1, "y2": y2})

    # Knock out some observations; the rows are addressed by the default
    # integer index, which is still in place before ``set_index``.
    frame.loc[[1, 3, 9], ["y1"]] = np.nan
    frame.loc[[3, 5, 7], ["y2"]] = np.nan

    return frame.set_index("date")
172+
173+
157174
@pytest.fixture(scope="session")
158175
def exog_ss_mod(exog_data):
159176
level_trend = st.LevelTrendComponent(name="trend", order=1, innovations_order=[0])
@@ -168,6 +185,23 @@ def exog_ss_mod(exog_data):
168185
return combined_model.build()
169186

170187

188+
@pytest.fixture(scope="session")
def exog_ss_mod_mv(exog_data_mv):
    """Build a two-series structural model: level trend plus one regressor.

    Both components are declared against the observed states ``y1`` and
    ``y2``. The regression component takes its single state name from the
    ``x1`` column of ``exog_data_mv`` and has innovations switched off.
    Returns the result of calling ``build()`` on the summed components.
    """
    endog_names = ("y1", "y2")
    regressor_names = exog_data_mv[["x1"]].columns.tolist()

    # Each component receives its own list so neither shares a mutable
    # object with the other.
    trend_component = st.LevelTrendComponent(
        name="trend",
        order=1,
        innovations_order=[0],
        observed_state_names=list(endog_names),
    )
    regression_component = st.RegressionComponent(
        name="exog",
        k_exog=1,  # a single exogenous regressor
        innovations=False,  # coefficient has no stochastic evolution
        state_names=regressor_names,
        observed_state_names=list(endog_names),
    )

    return (trend_component + regression_component).build()
203+
204+
171205
@pytest.fixture(scope="session")
172206
def ss_mod_multi_component(rng):
173207
ll = st.LevelTrendComponent(
@@ -208,6 +242,29 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
208242
return struct_model
209243

210244

245+
@pytest.fixture(scope="session")
def exog_pymc_mod_mv(exog_ss_mod_mv, exog_data_mv):
    """Create the PyMC model for the two-series exogenous statespace model.

    Registers, in order: ``P0_diag`` / ``P0`` (initial state covariance),
    ``initial_trend``, the ``data_exog`` container holding the ``x1``
    column, and ``beta_exog``. It then asks the statespace model to build
    its graph on the ``y1``/``y2`` columns with Kalman filter outputs saved
    to the inference data. Returns the populated ``pm.Model``.
    """
    # Plain data extraction does not need the model context.
    regressor_column = exog_data_mv["x1"].values[:, None]
    endog_frame = exog_data_mv[["y1", "y2"]]

    with pm.Model(coords=exog_ss_mod_mv.coords) as struct_model:
        P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
        pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])

        pm.Normal(
            "initial_trend", mu=[0], sigma=[0.005], dims=["endog_trend", "state_trend"]
        )

        pm.Data("data_exog", regressor_column, dims=["time", "state_exog"])
        pm.Normal("beta_exog", mu=0, sigma=1, dims=["endog_exog", "state_exog"])

        exog_ss_mod_mv.build_statespace_graph(
            endog_frame, save_kalman_filter_outputs_in_idata=True
        )

    return struct_model
266+
267+
211268
@pytest.fixture(scope="session")
212269
def pymc_mod_no_exog(ss_mod_no_exog, rng):
213270
y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"])
@@ -299,6 +356,15 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
299356
return idata
300357

301358

359+
@pytest.fixture(scope="session")
def idata_exog_mv(exog_pymc_mod_mv, rng, mock_pymc_sample):
    """Return inference data with posterior and prior groups for the mv model.

    Ten draws from a single chain with no tuning are taken with
    ``pm.sample``, then ten prior-predictive draws are appended to the same
    object.

    NOTE(review): ``mock_pymc_sample`` is requested but never referenced
    here; presumably it patches ``pm.sample`` for speed -- confirm against
    the fixture definition.
    """
    with exog_pymc_mod_mv:
        posterior = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
        prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

    posterior.extend(prior)
    return posterior
366+
367+
302368
@pytest.fixture(scope="session")
303369
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
304370
with pymc_mod_no_exog:
@@ -1002,6 +1068,51 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
10021068
assert_allclose(regression_effect, regression_effect_expected)
10031069

10041070

1071+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
@pytest.mark.parametrize("start", [None, -1, 5])
def test_forecast_with_exog_data_mv(rng, exog_ss_mod_mv, idata_exog_mv, start):
    """Check that a forecast scenario feeds through to both observed series.

    For each observed state, the forecast observation minus the forecast
    level must equal the regression coefficient times the scenario value of
    ``x1``.
    """
    # Scenario is zero everywhere except one very large value at step 5.
    scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
    scenario.iloc[5, 0] = 1e9

    forecast_idata = exog_ss_mod_mv.forecast(
        idata_exog_mv, start=start, periods=10, random_seed=rng, scenario=scenario
    )

    latent = exog_ss_mod_mv.extract_components_from_idata(forecast_idata).forecast_latent
    level_y1 = latent.sel(state="trend[level[y1]]")
    level_y2 = latent.sel(state="trend[level[y2]]")
    betas_y1 = latent.sel(state=["exog[x1[y1]]"])
    betas_y2 = latent.sel(state=["exog[x1[y2]]"])

    scenario.index.name = "time"

    def scenario_as_state(state_label):
        # Reshape the scenario frame into an xarray object carrying a
        # one-element ``state`` coordinate named after the given state.
        stacked = scenario.unstack().to_xarray()
        return stacked.rename({"level_0": "state"}).assign_coords(state=[state_label])

    # assumes observed_state positions 0 and 1 are y1 and y2 -- TODO confirm
    observed = forecast_idata.forecast_observed

    regression_effect_y1 = observed.isel(observed_state=0) - level_y1
    regression_effect_expected_y1 = (betas_y1 * scenario_as_state("exog[x1[y1]]")).sum(
        dim=["state"]
    )

    regression_effect_y2 = observed.isel(observed_state=1) - level_y2
    regression_effect_expected_y2 = (betas_y2 * scenario_as_state("exog[x1[y2]]")).sum(
        dim=["state"]
    )

    np.testing.assert_allclose(regression_effect_y1, regression_effect_expected_y1)
    np.testing.assert_allclose(regression_effect_y2, regression_effect_expected_y2)
1114+
1115+
10051116
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
10061117
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
10071118
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")

0 commit comments

Comments
 (0)