Skip to content

Commit 3925964

Browse files
Better handling of start date
1 parent eaf285c commit 3925964

File tree

2 files changed

+121
-37
lines changed

2 files changed

+121
-37
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,36 +1685,102 @@ def _build_forecast_index(
16851685
periods: int | None = None,
16861686
use_scenario_index: bool = False,
16871687
scenario: pd.DataFrame | np.ndarray | None = None,
1688-
) -> pd.Index:
1689-
if use_scenario_index:
1690-
if isinstance(scenario, pd.DataFrame):
1691-
return scenario.index
1692-
if isinstance(scenario, dict):
1693-
first_df = next(
1694-
(df for df in scenario.values() if isinstance(df, pd.DataFrame)), None
1695-
)
1696-
return first_df.index
1688+
) -> tuple[int | pd.Timestamp, pd.RangeIndex | pd.DatetimeIndex]:
1689+
"""
1690+
Construct a pandas Index for the requested forecast horizon.
16971691
1698-
# Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1699-
# use a range index.
1700-
is_datetime = isinstance(time_index, pd.DatetimeIndex)
1701-
forecast_index = None
1692+
Parameters
1693+
----------
1694+
time_index: pd.RangeIndex or pd.DatetimeIndex
1695+
Index of the data used to fit the model
1696+
start: int or pd.Timestamp, optional
1697+
Date from which to begin forecasting. If using a datetime index, integer start will be interpreted
1698+
as a positional index into time_index. Otherwise, start must be a value contained in time_index.
1699+
end: int or pd.Timestamp, optional
1700+
Date at which to end forecasting. If using a datetime index, end must be a timestamp.
1701+
periods: int, optional
1702+
Number of periods to forecast
1703+
scenario: pd.DataFrame, np.ndarray, optional
1704+
Scenario data to use for forecasting. If provided, the index of the scenario data will be used as the
1705+
forecast index. In that case, end and periods are ignored; start is used only when the scenario data has no index of its own.
1706+
use_scenario_index: bool, default False
1707+
If True, the index of the scenario data will be used as the forecast index.
17021708
1703-
if is_datetime:
1704-
freq = time_index.inferred_freq
17051709
1706-
if end is not None:
1707-
forecast_index = pd.date_range(start, end=end, freq=freq)
1708-
if periods is not None:
1709-
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1710+
Returns
1711+
-------
1712+
start: int | pd.Timestamp
1713+
The starting date index or time step from which to generate the forecasts.
1714+
1715+
forecast_index: pd.DatetimeIndex or pd.RangeIndex
1716+
Index for the forecast results
1717+
"""
1718+
1719+
def get_or_create_index(x, start=None):
1720+
if isinstance(x, pd.DataFrame | pd.Series):
1721+
return x.index
1722+
elif isinstance(x, dict):
1723+
return get_or_create_index(next(iter(x.values())))
1724+
elif isinstance(x, np.ndarray | list | tuple):
1725+
if start is None:
1726+
raise ValueError(
1727+
"Provided scenario has no index and no start date was provided. This combination "
1728+
"is ambiguous. Please provide a start date, or add an index to the scenario."
1729+
)
1730+
n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
1731+
return pd.RangeIndex(start, n + start, step=1, dtype="int")
1732+
else:
1733+
raise ValueError(f"{type(x)} is not a valid type for scenario data.")
1734+
1735+
x0_idx = None
1736+
1737+
if use_scenario_index:
1738+
forecast_index = get_or_create_index(scenario, start)
1739+
is_datetime = isinstance(forecast_index, pd.DatetimeIndex)
1740+
1741+
# If the user provided an index, we want to take it as-is (without removing the start value). Instead,
1742+
# step one back and use this as the start value.
1743+
delta = forecast_index.freq if is_datetime else 1
1744+
x0_idx = forecast_index[0] - delta
17101745

17111746
else:
1712-
if end is not None:
1713-
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
1714-
if periods is not None:
1715-
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1747+
# Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1748+
# use a range index.
1749+
is_datetime = isinstance(time_index, pd.DatetimeIndex)
1750+
forecast_index = None
1751+
1752+
if is_datetime:
1753+
freq = time_index.inferred_freq
1754+
if isinstance(start, int):
1755+
start = time_index[start]
1756+
if end is not None:
1757+
forecast_index = pd.date_range(start, end=end, freq=freq)
1758+
if periods is not None:
1759+
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1760+
1761+
else:
1762+
if end is not None:
1763+
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
1764+
if periods is not None:
1765+
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1766+
1767+
if is_datetime:
1768+
if forecast_index.freq != time_index.freq:
1769+
raise ValueError(
1770+
"The frequency of the forecast index must match the frequency on the data used "
1771+
f"to fit the model. Got {forecast_index.freq}, expected {time_index.freq}"
1772+
)
1773+
1774+
if x0_idx is None:
1775+
x0_idx, forecast_index = forecast_index[0], forecast_index[1:]
1776+
if x0_idx in forecast_index:
1777+
raise ValueError("x0_idx should not be in the forecast index")
1778+
if x0_idx not in time_index:
1779+
raise ValueError("start must be in the data index used to fit the model.")
17161780

1717-
return forecast_index
1781+
# The starting value should not be included in the forecast index. It will be used only to define x0 and P0,
1782+
# and no forecast will be associated with it.
1783+
return x0_idx, forecast_index
17181784

17191785
def _finalize_scenario_initialization(
17201786
self,
@@ -1876,7 +1942,7 @@ def forecast(
18761942
verbose=verbose,
18771943
)
18781944

1879-
forecast_index = self._build_forecast_index(
1945+
t0, forecast_index = self._build_forecast_index(
18801946
time_index=time_index,
18811947
start=start,
18821948
end=end,
@@ -1892,7 +1958,6 @@ def forecast(
18921958
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
18931959
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
18941960

1895-
t0 = forecast_index[0]
18961961
t0_idx = np.flatnonzero(time_index == t0)[0]
18971962

18981963
temp_coords["data_time"] = time_index

tests/statespace/test_statespace.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
292292
def _make_time_idx(mod, use_datetime_index=True):
293293
if use_datetime_index:
294294
mod._fit_coords["time"] = nile.index
295-
time_idx = pd.DatetimeIndex(mod._fit_coords["time"].values, freq=nile.index.inferred_freq)
295+
time_idx = nile.index
296296
else:
297297
mod._fit_coords["time"] = nile.reset_index().index
298298
time_idx = pd.RangeIndex(start=0, stop=nile.shape[0], step=1)
@@ -354,34 +354,50 @@ def test_forecast_index(use_datetime_index):
354354
ss_mod._fit_coords = dict()
355355
time_idx = _make_time_idx(ss_mod, use_datetime_index)
356356

357-
# From start and end date
357+
# From start and end
358358
start = time_idx[-1]
359-
end = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11
359+
delta = pd.DateOffset(years=10) if use_datetime_index else 11
360+
end = start + delta
360361

361-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
362-
assert start in forecast_idx
363-
assert forecast_idx.shape == (11,)
362+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
363+
assert start not in forecast_idx
364+
assert x0_index == start
365+
assert forecast_idx.shape == (10,)
364366

365367
# From start and periods
366368
start = time_idx[-1]
367-
periods = 10
369+
periods = 11
370+
371+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
372+
assert start not in forecast_idx
373+
assert x0_index == start
374+
assert forecast_idx.shape == (10,)
368375

369-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
376+
# From integer start
377+
start = 10
378+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
379+
delta = forecast_idx.freq if use_datetime_index else 1
380+
381+
assert x0_index == time_idx[start]
370382
assert forecast_idx.shape == (10,)
383+
assert (forecast_idx == time_idx[start + 1 : start + periods]).all()
371384

372385
# From scenario index
373386
scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
374-
forecast_idx = ss_mod._build_forecast_index(
387+
new_start, forecast_idx = ss_mod._build_forecast_index(
375388
time_index=time_idx, scenario=scenario, use_scenario_index=True
376389
)
390+
assert x0_index not in forecast_idx
391+
assert x0_index == (forecast_idx[0] - delta)
377392
assert forecast_idx.shape == (10,)
378393
assert forecast_idx.equals(scenario.index)
379394

380395
# From dictionary of scenarios
381396
scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
382-
forecast_idx = ss_mod._build_forecast_index(
397+
x0_index, forecast_idx = ss_mod._build_forecast_index(
383398
time_index=time_idx, scenario=scenario, use_scenario_index=True
384399
)
400+
assert x0_index == (forecast_idx[0] - delta)
385401
assert forecast_idx.shape == (10,)
386402
assert forecast_idx.equals(scenario["a"].index)
387403

@@ -484,7 +500,7 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
484500
scenario = data_type(np.zeros((10,)))
485501

486502
scenario = ss_mod._validate_scenario_data(scenario)
487-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
503+
t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=11)
488504
scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
489505

490506
assert isinstance(scenario, pd.DataFrame)
@@ -662,6 +678,9 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
662678
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
663679
)
664680

681+
print(scenario.index)
682+
print(level.coords)
683+
665684
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
666685
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
667686

0 commit comments

Comments
 (0)