Skip to content

Commit 9f5c16c

Browse files
Respect periods exactly when provided
1 parent c0966e1 commit 9f5c16c

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,7 @@ def _validate_forecast_args(
15501550
"one or the other to avoid this warning, or pass verbose = False."
15511551
)
15521552

1553-
def _get_fit_time_index(self) -> pd.Index:
1553+
def _get_fit_time_index(self) -> pd.RangeIndex | pd.DatetimeIndex:
15541554
time_index = self._fit_coords.get(TIME_DIM, None) if self._fit_coords is not None else None
15551555
if time_index is None:
15561556
raise ValueError(
@@ -1559,6 +1559,7 @@ def _get_fit_time_index(self) -> pd.Index:
15591559

15601560
if isinstance(time_index[0], pd.Timestamp):
15611561
time_index = pd.DatetimeIndex(time_index)
1562+
time_index.freq = time_index.inferred_freq
15621563
else:
15631564
time_index = np.array(time_index)
15641565

@@ -1728,26 +1729,33 @@ def _build_forecast_index(
17281729
Index for the forecast results
17291730
"""
17301731

1731-
def get_or_create_index(x, start=None):
1732+
def get_or_create_index(x, time_index, start=None):
17321733
if isinstance(x, pd.DataFrame | pd.Series):
17331734
return x.index
17341735
elif isinstance(x, dict):
1735-
return get_or_create_index(next(iter(x.values())))
1736+
return get_or_create_index(next(iter(x.values())), time_index, start)
17361737
elif isinstance(x, np.ndarray | list | tuple):
17371738
if start is None:
17381739
raise ValueError(
17391740
"Provided scenario has no index and no start date was provided. This combination "
17401741
"is ambiguous. Please provide a start date, or add an index to the scenario."
17411742
)
1743+
is_datetime_index = isinstance(time_index, pd.DatetimeIndex)
17421744
n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
1745+
1746+
if isinstance(start, int):
1747+
start = time_index[start]
1748+
if is_datetime_index:
1749+
return pd.date_range(start, periods=n, freq=time_index.freq)
17431750
return pd.RangeIndex(start, n + start, step=1, dtype="int")
1751+
17441752
else:
17451753
raise ValueError(f"{type(x)} is not a valid type for scenario data.")
17461754

17471755
x0_idx = None
17481756

17491757
if use_scenario_index:
1750-
forecast_index = get_or_create_index(scenario, start)
1758+
forecast_index = get_or_create_index(scenario, time_index, start)
17511759
is_datetime = isinstance(forecast_index, pd.DatetimeIndex)
17521760

17531761
# If the user provided an index, we want to take it as-is (without removing the start value). Instead,
@@ -1768,13 +1776,16 @@ def get_or_create_index(x, start=None):
17681776
if end is not None:
17691777
forecast_index = pd.date_range(start, end=end, freq=freq)
17701778
if periods is not None:
1771-
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1779+
# date_range include both start and end, but we're going to pop off the start later (it will be
1780+
# interpreted as x0). So we need to add 1 to the periods so the user gets "periods" number of
1781+
# forecasts back
1782+
forecast_index = pd.date_range(start, periods=periods + 1, freq=freq)
17721783

17731784
else:
17741785
if end is not None:
17751786
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
17761787
if periods is not None:
1777-
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1788+
forecast_index = pd.RangeIndex(start, start + periods + 1, step=1, dtype="int")
17781789

17791790
if is_datetime:
17801791
if forecast_index.freq != time_index.freq:
@@ -1933,7 +1944,7 @@ def forecast(
19331944
)
19341945
start = time_index[-1]
19351946

1936-
if not isinstance(scenario, dict):
1947+
if self._needs_exog_data and not isinstance(scenario, dict):
19371948
if len(self.data_names) > 1:
19381949
raise ValueError(
19391950
"Model needs more than one exogenous data to do forecasting. In this case, you must "
@@ -1962,7 +1973,6 @@ def forecast(
19621973
scenario=scenario,
19631974
use_scenario_index=use_scenario_index,
19641975
)
1965-
19661976
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
19671977
temp_coords = self._fit_coords.copy()
19681978

@@ -2011,11 +2021,12 @@ def forecast(
20112021
x0,
20122022
P0,
20132023
*matrices,
2014-
steps=len(forecast_index[:-1]),
2024+
steps=len(forecast_index),
20152025
dims=dims,
20162026
mode=self._fit_mode,
20172027
sequence_names=self.kalman_filter.seq_names,
20182028
k_endog=self.k_endog,
2029+
append_x0=False,
20192030
)
20202031

20212032
forecast_model.rvs_to_initial_values = {

0 commit comments

Comments
 (0)