@@ -1685,36 +1685,102 @@ def _build_forecast_index(
16851685 periods : int | None = None ,
16861686 use_scenario_index : bool = False ,
16871687 scenario : pd .DataFrame | np .ndarray | None = None ,
1688- ) -> pd .Index :
1689- if use_scenario_index :
1690- if isinstance (scenario , pd .DataFrame ):
1691- return scenario .index
1692- if isinstance (scenario , dict ):
1693- first_df = next (
1694- (df for df in scenario .values () if isinstance (df , pd .DataFrame )), None
1695- )
1696- return first_df .index
1688+ ) -> tuple [int | pd .Timestamp , pd .RangeIndex | pd .DatetimeIndex ]:
1689+ """
1690+ Construct a pandas Index for the requested forecast horizon.
16971691
1698- # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1699- # use a range index.
1700- is_datetime = isinstance (time_index , pd .DatetimeIndex )
1701- forecast_index = None
1692+ Parameters
1693+ ----------
1694+ time_index: pd.RangeIndex or pd.DatetimeIndex
1695+ Index of the data used to fit the model
1696+ start: int or pd.Timestamp, optional
1697+ Date from which to begin forecasting. If using a datetime index, integer start will be interpreted
1698+ as a positional index. Otherwise, start must be found inside the time_index
1699+ end: int or pd.Timestamp, optional
1700+ Date at which to end forecasting. If using a datetime index, end must be a timestamp.
1701+ periods: int, optional
1702+ Number of periods to forecast
1703+ scenario: pd.DataFrame, np.ndarray, optional
1704+ Scenario data to use for forecasting. When use_scenario_index is True, the index of the scenario data is
1705+ used as the forecast index and end and periods are ignored; start is only used if the scenario has no index.
1706+ use_scenario_index: bool, default False
1707+ If True, the index of the scenario data will be used as the forecast index.
17021708
1703- if is_datetime :
1704- freq = time_index .inferred_freq
17051709
1706- if end is not None :
1707- forecast_index = pd .date_range (start , end = end , freq = freq )
1708- if periods is not None :
1709- forecast_index = pd .date_range (start , periods = periods , freq = freq )
1710+ Returns
1711+ -------
1712+ start: int | pd.Timestamp
1713+ The starting date index or time step from which to generate the forecasts.
1714+
1715+ forecast_index: pd.DatetimeIndex or pd.RangeIndex
1716+ Index for the forecast results
1717+ """
1718+
def get_or_create_index(x, start=None):
    """
    Return the index carried by scenario data, or build one for index-less data.

    Parameters
    ----------
    x: pd.DataFrame, pd.Series, dict, np.ndarray, list or tuple
        Scenario data. For a dict, only the first value (in iteration order) is inspected.
    start: int, optional
        First value of the generated RangeIndex. Only used when the data has no index of its own
        (np.ndarray, list or tuple), in which case it is required.

    Returns
    -------
    index: pd.Index
        The existing index of a DataFrame/Series, or a RangeIndex running from ``start`` to
        ``start + len(x)`` for index-less data.

    Raises
    ------
    ValueError
        If the data has no index and ``start`` is None, if a dict is empty, or if ``x`` is of an
        unsupported type.
    """
    if isinstance(x, (pd.DataFrame, pd.Series)):
        return x.index

    if isinstance(x, dict):
        if not x:
            raise ValueError("Scenario dictionary is empty; cannot determine a forecast index.")
        # Forward ``start`` so that a dictionary of index-less arrays can still be given a RangeIndex.
        return get_or_create_index(next(iter(x.values())), start)

    if isinstance(x, (np.ndarray, list, tuple)):
        if start is None:
            raise ValueError(
                "Provided scenario has no index and no start date was provided. This combination "
                "is ambiguous. Please provide a start date, or add an index to the scenario."
            )
        n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
        return pd.RangeIndex(start, n + start, step=1, dtype="int")

    raise ValueError(f"{type(x)} is not a valid type for scenario data.")
1734+
1735+ x0_idx = None
1736+
1737+ if use_scenario_index :
1738+ forecast_index = get_or_create_index (scenario , start )
1739+ is_datetime = isinstance (forecast_index , pd .DatetimeIndex )
1740+
1741+ # If the user provided an index, we want to take it as-is (without removing the start value). Instead,
1742+ # step one back and use this as the start value.
1743+ delta = forecast_index .freq if is_datetime else 1
1744+ x0_idx = forecast_index [0 ] - delta
17101745
17111746 else :
1712- if end is not None :
1713- forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1714- if periods is not None :
1715- forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1747+ # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1748+ # use a range index.
1749+ is_datetime = isinstance (time_index , pd .DatetimeIndex )
1750+ forecast_index = None
1751+
1752+ if is_datetime :
1753+ freq = time_index .inferred_freq
1754+ if isinstance (start , int ):
1755+ start = time_index [start ]
1756+ if end is not None :
1757+ forecast_index = pd .date_range (start , end = end , freq = freq )
1758+ if periods is not None :
1759+ forecast_index = pd .date_range (start , periods = periods , freq = freq )
1760+
1761+ else :
1762+ if end is not None :
1763+ forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1764+ if periods is not None :
1765+ forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1766+
1767+ if is_datetime :
1768+ if forecast_index .freq != time_index .freq :
1769+ raise ValueError (
1770+ "The frequency of the forecast index must match the frequency on the data used "
1771+ f"to fit the model. Got { forecast_index .freq } , expected { time_index .freq } "
1772+ )
1773+
1774+ if x0_idx is None :
1775+ x0_idx , forecast_index = forecast_index [0 ], forecast_index [1 :]
1776+ if x0_idx in forecast_index :
1777+ raise ValueError ("x0_idx should not be in the forecast index" )
1778+ if x0_idx not in time_index :
1779+ raise ValueError ("start must be in the data index used to fit the model." )
17161780
1717- return forecast_index
1781+ # The starting value should not be included in the forecast index. It will be used only to define x0 and P0,
1782+ # and no forecast will be associated with it.
1783+ return x0_idx , forecast_index
17181784
17191785 def _finalize_scenario_initialization (
17201786 self ,
@@ -1876,7 +1942,7 @@ def forecast(
18761942 verbose = verbose ,
18771943 )
18781944
1879- forecast_index = self ._build_forecast_index (
1945+ t0 , forecast_index = self ._build_forecast_index (
18801946 time_index = time_index ,
18811947 start = start ,
18821948 end = end ,
@@ -1892,7 +1958,6 @@ def forecast(
18921958 if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
18931959 dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
18941960
1895- t0 = forecast_index [0 ]
18961961 t0_idx = np .flatnonzero (time_index == t0 )[0 ]
18971962
18981963 temp_coords ["data_time" ] = time_index
0 commit comments