@@ -1697,36 +1697,102 @@ def _build_forecast_index(
16971697 periods : int | None = None ,
16981698 use_scenario_index : bool = False ,
16991699 scenario : pd .DataFrame | np .ndarray | None = None ,
1700- ) -> pd .Index :
1701- if use_scenario_index :
1702- if isinstance (scenario , pd .DataFrame ):
1703- return scenario .index
1704- if isinstance (scenario , dict ):
1705- first_df = next (
1706- (df for df in scenario .values () if isinstance (df , pd .DataFrame )), None
1707- )
1708- return first_df .index
1700+ ) -> tuple [int | pd .Timestamp , pd .RangeIndex | pd .DatetimeIndex ]:
1701+ """
1702+ Construct a pandas Index for the requested forecast horizon.
17091703
1710- # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1711- # use a range index.
1712- is_datetime = isinstance (time_index , pd .DatetimeIndex )
1713- forecast_index = None
1704+ Parameters
1705+ ----------
1706+ time_index: pd.RangeIndex or pd.DatetimeIndex
1707+ Index of the data used to fit the model
1708+ start: int or pd.Timestamp, optional
1709+ Date from which to begin forecasting. If using a datetime index, integer start will be interpreted
1710+ as a positional index. Otherwise, start must be found inside the time_index
1711+ end: int or pd.Timestamp, optional
1712+ Date at which to end forecasting. If using a datetime index, end must be a timestamp.
1713+ periods: int, optional
1714+ Number of periods to forecast
1715+ scenario: pd.DataFrame, np.ndarray, optional
1716+ Scenario data to use for forecasting. If provided, the index of the scenario data will be used as the
1717+ forecast index. If provided, start, end, and periods will be ignored.
1718+ use_scenario_index: bool, default False
1719+ If True, the index of the scenario data will be used as the forecast index.
17141720
1715- if is_datetime :
1716- freq = time_index .inferred_freq
17171721
1718- if end is not None :
1719- forecast_index = pd .date_range (start , end = end , freq = freq )
1720- if periods is not None :
1721- forecast_index = pd .date_range (start , periods = periods , freq = freq )
1722+ Returns
1723+ -------
1724+ start: int | pd.TimeStamp
1725+ The starting date index or time step from which to generate the forecasts.
1726+
1727+ forecast_index: pd.DatetimeIndex or pd.RangeIndex
1728+ Index for the forecast results
1729+ """
1730+
1731+ def get_or_create_index (x , start = None ):
1732+ if isinstance (x , pd .DataFrame | pd .Series ):
1733+ return x .index
1734+ elif isinstance (x , dict ):
1735+ return get_or_create_index (next (iter (x .values ())))
1736+ elif isinstance (x , np .ndarray | list | tuple ):
1737+ if start is None :
1738+ raise ValueError (
1739+ "Provided scenario has no index and no start date was provided. This combination "
1740+ "is ambiguous. Please provide a start date, or add an index to the scenario."
1741+ )
1742+ n = x .shape [0 ] if isinstance (x , np .ndarray ) else len (x )
1743+ return pd .RangeIndex (start , n + start , step = 1 , dtype = "int" )
1744+ else :
1745+ raise ValueError (f"{ type (x )} is not a valid type for scenario data." )
1746+
1747+ x0_idx = None
1748+
1749+ if use_scenario_index :
1750+ forecast_index = get_or_create_index (scenario , start )
1751+ is_datetime = isinstance (forecast_index , pd .DatetimeIndex )
1752+
1753+ # If the user provided an index, we want to take it as-is (without removing the start value). Instead,
1754+ # step one back and use this as the start value.
1755+ delta = forecast_index .freq if is_datetime else 1
1756+ x0_idx = forecast_index [0 ] - delta
17221757
17231758 else :
1724- if end is not None :
1725- forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1726- if periods is not None :
1727- forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1759+ # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1760+ # use a range index.
1761+ is_datetime = isinstance (time_index , pd .DatetimeIndex )
1762+ forecast_index = None
1763+
1764+ if is_datetime :
1765+ freq = time_index .inferred_freq
1766+ if isinstance (start , int ):
1767+ start = time_index [start ]
1768+ if end is not None :
1769+ forecast_index = pd .date_range (start , end = end , freq = freq )
1770+ if periods is not None :
1771+ forecast_index = pd .date_range (start , periods = periods , freq = freq )
1772+
1773+ else :
1774+ if end is not None :
1775+ forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1776+ if periods is not None :
1777+ forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1778+
1779+ if is_datetime :
1780+ if forecast_index .freq != time_index .freq :
1781+ raise ValueError (
1782+ "The frequency of the forecast index must match the frequency on the data used "
1783+ f"to fit the model. Got { forecast_index .freq } , expected { time_index .freq } "
1784+ )
1785+
1786+ if x0_idx is None :
1787+ x0_idx , forecast_index = forecast_index [0 ], forecast_index [1 :]
1788+ if x0_idx in forecast_index :
1789+ raise ValueError ("x0_idx should not be in the forecast index" )
1790+ if x0_idx not in time_index :
1791+ raise ValueError ("start must be in the data index used to fit the model." )
17281792
1729- return forecast_index
1793+ # The starting value should not be included in the forecast index. It will be used only to define x0 and P0,
1794+ # and no forecast will be associated with it.
1795+ return x0_idx , forecast_index
17301796
17311797 def _finalize_scenario_initialization (
17321798 self ,
@@ -1888,7 +1954,7 @@ def forecast(
18881954 verbose = verbose ,
18891955 )
18901956
1891- forecast_index = self ._build_forecast_index (
1957+ t0 , forecast_index = self ._build_forecast_index (
18921958 time_index = time_index ,
18931959 start = start ,
18941960 end = end ,
@@ -1904,7 +1970,6 @@ def forecast(
19041970 if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
19051971 dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
19061972
1907- t0 = forecast_index [0 ]
19081973 t0_idx = np .flatnonzero (time_index == t0 )[0 ]
19091974
19101975 temp_coords ["data_time" ] = time_index
0 commit comments