Skip to content

Commit 9c6e939

Browse files
Respect periods exactly when provided
1 parent 1a7d961 commit 9c6e939

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,7 @@ def _validate_forecast_args(
15381538
"one or the other to avoid this warning, or pass verbose = False."
15391539
)
15401540

1541-
def _get_fit_time_index(self) -> pd.Index:
1541+
def _get_fit_time_index(self) -> pd.RangeIndex | pd.DatetimeIndex:
15421542
time_index = self._fit_coords.get(TIME_DIM, None) if self._fit_coords is not None else None
15431543
if time_index is None:
15441544
raise ValueError(
@@ -1547,6 +1547,7 @@ def _get_fit_time_index(self) -> pd.Index:
15471547

15481548
if isinstance(time_index[0], pd.Timestamp):
15491549
time_index = pd.DatetimeIndex(time_index)
1550+
time_index.freq = time_index.inferred_freq
15501551
else:
15511552
time_index = np.array(time_index)
15521553

@@ -1716,26 +1717,33 @@ def _build_forecast_index(
17161717
Index for the forecast results
17171718
"""
17181719

1719-
def get_or_create_index(x, start=None):
1720+
def get_or_create_index(x, time_index, start=None):
17201721
if isinstance(x, pd.DataFrame | pd.Series):
17211722
return x.index
17221723
elif isinstance(x, dict):
1723-
return get_or_create_index(next(iter(x.values())))
1724+
return get_or_create_index(next(iter(x.values())), time_index, start)
17241725
elif isinstance(x, np.ndarray | list | tuple):
17251726
if start is None:
17261727
raise ValueError(
17271728
"Provided scenario has no index and no start date was provided. This combination "
17281729
"is ambiguous. Please provide a start date, or add an index to the scenario."
17291730
)
1731+
is_datetime_index = isinstance(time_index, pd.DatetimeIndex)
17301732
n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
1733+
1734+
if isinstance(start, int):
1735+
start = time_index[start]
1736+
if is_datetime_index:
1737+
return pd.date_range(start, periods=n, freq=time_index.freq)
17311738
return pd.RangeIndex(start, n + start, step=1, dtype="int")
1739+
17321740
else:
17331741
raise ValueError(f"{type(x)} is not a valid type for scenario data.")
17341742

17351743
x0_idx = None
17361744

17371745
if use_scenario_index:
1738-
forecast_index = get_or_create_index(scenario, start)
1746+
forecast_index = get_or_create_index(scenario, time_index, start)
17391747
is_datetime = isinstance(forecast_index, pd.DatetimeIndex)
17401748

17411749
# If the user provided an index, we want to take it as-is (without removing the start value). Instead,
@@ -1756,13 +1764,16 @@ def get_or_create_index(x, start=None):
17561764
if end is not None:
17571765
forecast_index = pd.date_range(start, end=end, freq=freq)
17581766
if periods is not None:
1759-
forecast_index = pd.date_range(start, periods=periods, freq=freq)
1767+
# date_range include both start and end, but we're going to pop off the start later (it will be
1768+
# interpreted as x0). So we need to add 1 to the periods so the user gets "periods" number of
1769+
# forecasts back
1770+
forecast_index = pd.date_range(start, periods=periods + 1, freq=freq)
17601771

17611772
else:
17621773
if end is not None:
17631774
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
17641775
if periods is not None:
1765-
forecast_index = pd.RangeIndex(start, start + periods, step=1, dtype="int")
1776+
forecast_index = pd.RangeIndex(start, start + periods + 1, step=1, dtype="int")
17661777

17671778
if is_datetime:
17681779
if forecast_index.freq != time_index.freq:
@@ -1921,7 +1932,7 @@ def forecast(
19211932
)
19221933
start = time_index[-1]
19231934

1924-
if not isinstance(scenario, dict):
1935+
if self._needs_exog_data and not isinstance(scenario, dict):
19251936
if len(self.data_names) > 1:
19261937
raise ValueError(
19271938
"Model needs more than one exogenous data to do forecasting. In this case, you must "
@@ -1950,7 +1961,6 @@ def forecast(
19501961
scenario=scenario,
19511962
use_scenario_index=use_scenario_index,
19521963
)
1953-
19541964
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
19551965
temp_coords = self._fit_coords.copy()
19561966

@@ -1999,11 +2009,12 @@ def forecast(
19992009
x0,
20002010
P0,
20012011
*matrices,
2002-
steps=len(forecast_index[:-1]),
2012+
steps=len(forecast_index),
20032013
dims=dims,
20042014
mode=self._fit_mode,
20052015
sequence_names=self.kalman_filter.seq_names,
20062016
k_endog=self.k_endog,
2017+
append_x0=False,
20072018
)
20082019

20092020
forecast_model.rvs_to_initial_values = {

0 commit comments

Comments
 (0)