Skip to content

Commit 8af6a16

Browse files
Better handling of start date
1 parent 2371fec commit 8af6a16

File tree

2 files changed

+121
-37
lines changed

2 files changed

+121
-37
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,36 +1697,102 @@ def _build_forecast_index(
16971697
periods: int | None = None,
16981698
use_scenario_index: bool = False,
16991699
scenario: pd.DataFrame | np.ndarray | None = None,
1700-
) -> pd.Index:
1701-
if use_scenario_index:
1702-
if isinstance(scenario, pd.DataFrame):
1703-
return scenario.index
1704-
if isinstance(scenario, dict):
1705-
first_df = next(
1706-
(df for df in scenario.values() if isinstance(df, pd.DataFrame)), None
1707-
)
1708-
return first_df.index
1700+
) -> tuple[int | pd.Timestamp, pd.RangeIndex | pd.DatetimeIndex]:
1701+
"""
1702+
Construct a pandas Index for the requested forecast horizon.
17091703
1710-
# Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1711-
# use a range index.
1712-
is_datetime = isinstance(time_index, pd.DatetimeIndex)
1713-
forecast_index = None
1704+
Parameters
1705+
----------
1706+
time_index: pd.RangeIndex or pd.DatetimeIndex
1707+
Index of the data used to fit the model
1708+
start: int or pd.Timestamp, optional
1709+
Date from which to begin forecasting. If using a datetime index, integer start will be interpreted
1710+
as a positional index. Otherwise, start must be found inside the time_index
1711+
end: int or pd.Timestamp, optional
1712+
Date at which to end forecasting. If using a datetime index, end must be a timestamp.
1713+
periods: int, optional
1714+
Number of periods to forecast
1715+
scenario: pd.DataFrame, np.ndarray, optional
1716+
Scenario data to use for forecasting. If provided, the index of the scenario data will be used as the
1717+
forecast index. If provided, start, end, and periods will be ignored.
1718+
use_scenario_index: bool, default False
1719+
If True, the index of the scenario data will be used as the forecast index.
17141720
1715-
if is_datetime:
1716-
freq = time_index.inferred_freq
17171721
1718-
if end is not None:
1719-
forecast_index = pd.date_range(start, end=end, freq=freq)
1720-
if periods is not None:
1721-
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1722+
Returns
1723+
-------
1724+
start: int | pd.TimeStamp
1725+
The starting date index or time step from which to generate the forecasts.
1726+
1727+
forecast_index: pd.DatetimeIndex or pd.RangeIndex
1728+
Index for the forecast results
1729+
"""
1730+
1731+
def get_or_create_index(x, start=None):
1732+
if isinstance(x, pd.DataFrame | pd.Series):
1733+
return x.index
1734+
elif isinstance(x, dict):
1735+
return get_or_create_index(next(iter(x.values())))
1736+
elif isinstance(x, np.ndarray | list | tuple):
1737+
if start is None:
1738+
raise ValueError(
1739+
"Provided scenario has no index and no start date was provided. This combination "
1740+
"is ambiguous. Please provide a start date, or add an index to the scenario."
1741+
)
1742+
n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
1743+
return pd.RangeIndex(start, n + start, step=1, dtype="int")
1744+
else:
1745+
raise ValueError(f"{type(x)} is not a valid type for scenario data.")
1746+
1747+
x0_idx = None
1748+
1749+
if use_scenario_index:
1750+
forecast_index = get_or_create_index(scenario, start)
1751+
is_datetime = isinstance(forecast_index, pd.DatetimeIndex)
1752+
1753+
# If the user provided an index, we want to take it as-is (without removing the start value). Instead,
1754+
# step one back and use this as the start value.
1755+
delta = forecast_index.freq if is_datetime else 1
1756+
x0_idx = forecast_index[0] - delta
17221757

17231758
else:
1724-
if end is not None:
1725-
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
1726-
if periods is not None:
1727-
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1759+
# Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1760+
# use a range index.
1761+
is_datetime = isinstance(time_index, pd.DatetimeIndex)
1762+
forecast_index = None
1763+
1764+
if is_datetime:
1765+
freq = time_index.inferred_freq
1766+
if isinstance(start, int):
1767+
start = time_index[start]
1768+
if end is not None:
1769+
forecast_index = pd.date_range(start, end=end, freq=freq)
1770+
if periods is not None:
1771+
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1772+
1773+
else:
1774+
if end is not None:
1775+
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
1776+
if periods is not None:
1777+
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1778+
1779+
if is_datetime:
1780+
if forecast_index.freq != time_index.freq:
1781+
raise ValueError(
1782+
"The frequency of the forecast index must match the frequency on the data used "
1783+
f"to fit the model. Got {forecast_index.freq}, expected {time_index.freq}"
1784+
)
1785+
1786+
if x0_idx is None:
1787+
x0_idx, forecast_index = forecast_index[0], forecast_index[1:]
1788+
if x0_idx in forecast_index:
1789+
raise ValueError("x0_idx should not be in the forecast index")
1790+
if x0_idx not in time_index:
1791+
raise ValueError("start must be in the data index used to fit the model.")
17281792

1729-
return forecast_index
1793+
# The starting value should not be included in the forecast index. It will be used only to define x0 and P0,
1794+
# and no forecast will be associated with it.
1795+
return x0_idx, forecast_index
17301796

17311797
def _finalize_scenario_initialization(
17321798
self,
@@ -1888,7 +1954,7 @@ def forecast(
18881954
verbose=verbose,
18891955
)
18901956

1891-
forecast_index = self._build_forecast_index(
1957+
t0, forecast_index = self._build_forecast_index(
18921958
time_index=time_index,
18931959
start=start,
18941960
end=end,
@@ -1904,7 +1970,6 @@ def forecast(
19041970
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
19051971
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
19061972

1907-
t0 = forecast_index[0]
19081973
t0_idx = np.flatnonzero(time_index == t0)[0]
19091974

19101975
temp_coords["data_time"] = time_index

tests/statespace/test_statespace.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
292292
def _make_time_idx(mod, use_datetime_index=True):
293293
if use_datetime_index:
294294
mod._fit_coords["time"] = nile.index
295-
time_idx = pd.DatetimeIndex(mod._fit_coords["time"].values, freq=nile.index.inferred_freq)
295+
time_idx = nile.index
296296
else:
297297
mod._fit_coords["time"] = nile.reset_index().index
298298
time_idx = pd.RangeIndex(start=0, stop=nile.shape[0], step=1)
@@ -354,34 +354,50 @@ def test_forecast_index(use_datetime_index):
354354
ss_mod._fit_coords = dict()
355355
time_idx = _make_time_idx(ss_mod, use_datetime_index)
356356

357-
# From start and end date
357+
# From start and end
358358
start = time_idx[-1]
359-
end = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11
359+
delta = pd.DateOffset(years=10) if use_datetime_index else 11
360+
end = start + delta
360361

361-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
362-
assert start in forecast_idx
363-
assert forecast_idx.shape == (11,)
362+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end)
363+
assert start not in forecast_idx
364+
assert x0_index == start
365+
assert forecast_idx.shape == (10,)
364366

365367
# From start and periods
366368
start = time_idx[-1]
367-
periods = 10
369+
periods = 11
370+
371+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
372+
assert start not in forecast_idx
373+
assert x0_index == start
374+
assert forecast_idx.shape == (10,)
368375

369-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
376+
# From integer start
377+
start = 10
378+
x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods)
379+
delta = forecast_idx.freq if use_datetime_index else 1
380+
381+
assert x0_index == time_idx[start]
370382
assert forecast_idx.shape == (10,)
383+
assert (forecast_idx == time_idx[start + 1 : start + periods]).all()
371384

372385
# From scenario index
373386
scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])
374-
forecast_idx = ss_mod._build_forecast_index(
387+
new_start, forecast_idx = ss_mod._build_forecast_index(
375388
time_index=time_idx, scenario=scenario, use_scenario_index=True
376389
)
390+
assert x0_index not in forecast_idx
391+
assert x0_index == (forecast_idx[0] - delta)
377392
assert forecast_idx.shape == (10,)
378393
assert forecast_idx.equals(scenario.index)
379394

380395
# From dictionary of scenarios
381396
scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])}
382-
forecast_idx = ss_mod._build_forecast_index(
397+
x0_index, forecast_idx = ss_mod._build_forecast_index(
383398
time_index=time_idx, scenario=scenario, use_scenario_index=True
384399
)
400+
assert x0_index == (forecast_idx[0] - delta)
385401
assert forecast_idx.shape == (10,)
386402
assert forecast_idx.equals(scenario["a"].index)
387403

@@ -484,7 +500,7 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
484500
scenario = data_type(np.zeros((10,)))
485501

486502
scenario = ss_mod._validate_scenario_data(scenario)
487-
forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10)
503+
t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=11)
488504
scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx)
489505

490506
assert isinstance(scenario, pd.DataFrame)
@@ -662,6 +678,9 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
662678
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
663679
)
664680

681+
print(scenario.index)
682+
print(level.coords)
683+
665684
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
666685
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
667686

0 commit comments

Comments
 (0)