Skip to content

Commit 99fcf22

Browse files
Improve forecast-related tests
1 parent 9c6e939 commit 99fcf22

File tree

1 file changed

+32
-21
lines changed

1 file changed

+32
-21
lines changed

tests/statespace/test_statespace.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def test_forecast_index(use_datetime_index):
366366

367367
# From start and periods
368368
start = time_idx[-1]
369-
periods = 11
369+
periods = 10
370370

371371
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
372372
assert start not in forecast_idx
@@ -380,7 +380,7 @@ def test_forecast_index(use_datetime_index):
380380

381381
assert x0_index == time_idx[start]
382382
assert forecast_idx.shape == (10,)
383-
assert (forecast_idx == time_idx[start + 1 : start + periods]).all()
383+
assert (forecast_idx == time_idx[start + 1 : start + periods + 1]).all()
384384

385385
# From scenario index
386386
scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
@@ -500,7 +500,7 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
500500
scenario = data_type(np.zeros((10,)))
501501

502502
scenario = ss_mod._validate_scenario_data(scenario)
503-
t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=11)
503+
t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
504504
scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
505505

506506
assert isinstance(scenario, pd.DataFrame)
@@ -514,11 +514,8 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
514514
ids=["series", "dataframe", "array", "list", "tuple"],
515515
)
516516
@pytest.mark.parametrize("use_datetime_index", [True, False])
517-
def test_finalize_secenario_dict(data_type, use_datetime_index):
518-
if data_type is pd.DataFrame:
519-
# Ensure dataframes have the correct column name
520-
data_type = partial(pd.DataFrame, columns=["column_1"])
521-
517+
@pytest.mark.parametrize("use_scenario_index", [True, False])
518+
def test_finalize_secenario_dict(data_type, use_datetime_index, use_scenario_index):
522519
data_info = {
523520
"a": {"shape": (None, 1), "dims": ("time", "features_a")},
524521
"b": {"shape": (None, 2), "dims": ("time", "features_b")},
@@ -534,13 +531,38 @@ def test_finalize_secenario_dict(data_type, use_datetime_index):
534531
ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1", "column_2"])
535532
time_idx = _make_time_idx(ss_mod, use_datetime_index)
536533

534+
initial_index = (
535+
pd.date_range(start=time_idx[-1], periods=10, freq=time_idx.freq)
536+
if use_datetime_index
537+
else pd.RangeIndex(time_idx[-1], time_idx[-1] + 10, 1)
538+
)
539+
540+
if data_type is pd.DataFrame:
541+
# Ensure dataframes have the correct column name
542+
data_type = partial(pd.DataFrame, columns=["column_1"], index=initial_index)
543+
elif data_type is pd.Series:
544+
data_type = partial(pd.Series, index=initial_index)
545+
537546
scenario = {
538547
"a": data_type(np.zeros((10,))),
539-
"b": pd.DataFrame(np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"]),
548+
"b": pd.DataFrame(
549+
np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"], index=initial_index
550+
),
540551
}
541552

542553
scenario = ss_mod._validate_scenario_data(scenario)
543-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
554+
555+
if use_scenario_index and data_type not in [np.array, list, tuple]:
556+
t0, forecast_idx = ss_mod._build_forecast_index(
557+
time_idx, scenario=scenario, periods=10, use_scenario_index=True
558+
)
559+
elif use_scenario_index and data_type in [np.array, list, tuple]:
560+
t0, forecast_idx = ss_mod._build_forecast_index(
561+
time_idx, scenario=scenario, start=-1, periods=10, use_scenario_index=True
562+
)
563+
else:
564+
t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
565+
544566
scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
545567

546568
assert list(scenario.keys()) == ["a", "b"]
@@ -678,18 +700,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
678700
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
679701
)
680702

681-
print(scenario.index)
682-
print(level.coords)
683-
684703
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
685704
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
686705

687706
assert_allclose(regression_effect, regression_effect_expected)
688-
689-
# t_5 = forecast_idata.forecast_observed.isel(time=7, observed_state=0).to_numpy()
690-
# not_t_5 = (
691-
# forecast_idata.forecast_observed.isel(time=np.arange(10) != 7, observed_state=0)
692-
# .mean(dim="time")
693-
# .to_numpy()
694-
# )
695-
# assert t_5.shape == not_t_5.shape

0 commit comments

Comments
 (0)