Skip to content

Commit f92782c

Browse files
Improve scenario forecasting test
1 parent b6fcbc4 commit f92782c

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

tests/statespace/test_statespace.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,14 +647,30 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
647647
scenario.iloc[5, 0] = 1e9
648648

649649
forecast_idata = exog_ss_mod.forecast(
650-
idata_exog, periods=10, random_seed=rng, scenario={"data_exog": scenario}
650+
idata_exog, periods=10, random_seed=rng, scenario=scenario
651651
)
652652

653-
# TODO: Why does it end up on t=7?
654-
t_5 = forecast_idata.forecast_observed.isel(time=7, observed_state=0).to_numpy()
655-
not_t_5 = (
656-
forecast_idata.forecast_observed.isel(time=np.arange(10) != 7, observed_state=0)
657-
.mean(dim="time")
658-
.to_numpy()
653+
components = exog_ss_mod.extract_components_from_idata(forecast_idata)
654+
level = components.forecast_latent.sel(state="LevelTrend[level]")
655+
betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])
656+
657+
scenario.index.name = "time"
658+
scenario_xr = (
659+
scenario.unstack()
660+
.to_xarray()
661+
.rename({"level_0": "state"})
662+
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
659663
)
660-
assert t_5.shape == not_t_5.shape
664+
665+
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
666+
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
667+
668+
assert_allclose(regression_effect, regression_effect_expected)
669+
670+
# t_5 = forecast_idata.forecast_observed.isel(time=7, observed_state=0).to_numpy()
671+
# not_t_5 = (
672+
# forecast_idata.forecast_observed.isel(time=np.arange(10) != 7, observed_state=0)
673+
# .mean(dim="time")
674+
# .to_numpy()
675+
# )
676+
# assert t_5.shape == not_t_5.shape

0 commit comments

Comments
 (0)