@@ -1685,36 +1685,102 @@ def _build_forecast_index(
1685
1685
periods : int | None = None ,
1686
1686
use_scenario_index : bool = False ,
1687
1687
scenario : pd .DataFrame | np .ndarray | None = None ,
1688
- ) -> pd .Index :
1689
- if use_scenario_index :
1690
- if isinstance (scenario , pd .DataFrame ):
1691
- return scenario .index
1692
- if isinstance (scenario , dict ):
1693
- first_df = next (
1694
- (df for df in scenario .values () if isinstance (df , pd .DataFrame )), None
1695
- )
1696
- return first_df .index
1688
+ ) -> tuple [int | pd .Timestamp , pd .RangeIndex | pd .DatetimeIndex ]:
1689
+ """
1690
+ Construct a pandas Index for the requested forecast horizon.
1697
1691
1698
- # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1699
- # use a range index.
1700
- is_datetime = isinstance (time_index , pd .DatetimeIndex )
1701
- forecast_index = None
1692
+ Parameters
1693
+ ----------
1694
+ time_index: pd.RangeIndex or pd.DatetimeIndex
1695
+ Index of the data used to fit the model
1696
+ start: int or pd.Timestamp, optional
1697
+ Date from which to begin forecasting. If using a datetime index, integer start will be interpreted
1698
+ as a positional index. Otherwise, start must be found inside the time_index
1699
+ end: int or pd.Timestamp, optional
1700
+ Date at which to end forecasting. If using a datetime index, end must be a timestamp.
1701
+ periods: int, optional
1702
+ Number of periods to forecast
1703
+ scenario: pd.DataFrame, np.ndarray, optional
1704
+ Scenario data to use for forecasting. When use_scenario_index is True, the index of the scenario data is used as the
1705
+ forecast index and end and periods are ignored; start is then only used to build a range index for scenarios without an index.
1706
+ use_scenario_index: bool, default False
1707
+ If True, the index of the scenario data will be used as the forecast index.
1702
1708
1703
- if is_datetime :
1704
- freq = time_index .inferred_freq
1705
1709
1706
- if end is not None :
1707
- forecast_index = pd .date_range (start , end = end , freq = freq )
1708
- if periods is not None :
1709
- forecast_index = pd .date_range (start , periods = periods , freq = freq )
1710
+ Returns
1711
+ -------
1712
+ start: int | pd.Timestamp
1713
+ The starting date index or time step from which to generate the forecasts.
1714
+
1715
+ forecast_index: pd.DatetimeIndex or pd.RangeIndex
1716
+ Index for the forecast results
1717
+ """
1718
+
1719
def get_or_create_index(x, start=None):
    """
    Return the index carried by scenario data, or build a range index for it.

    Parameters
    ----------
    x: pd.DataFrame, pd.Series, dict, np.ndarray, list or tuple
        Scenario data. For a dict, the first value determines the index.
    start: int, optional
        First value of the generated range index. Only used when the scenario data has no index of its own
        (arrays, lists and tuples).

    Returns
    -------
    index: pd.Index
        The index of the scenario data if it has one, otherwise a ``pd.RangeIndex`` beginning at ``start`` with
        one entry per row of the scenario data.

    Raises
    ------
    ValueError
        If the scenario data has no index and ``start`` is None, if a scenario dictionary is empty, or if the
        scenario data is of an unsupported type.
    """
    if isinstance(x, (pd.DataFrame, pd.Series)):
        return x.index

    if isinstance(x, dict):
        if not x:
            raise ValueError(
                "Provided scenario dictionary is empty, so no forecast index can be determined."
            )
        # Forward ``start`` so that a dictionary of index-free arrays can still be resolved to a range index.
        return get_or_create_index(next(iter(x.values())), start)

    if isinstance(x, (np.ndarray, list, tuple)):
        if start is None:
            raise ValueError(
                "Provided scenario has no index and no start date was provided. This combination "
                "is ambiguous. Please provide a start date, or add an index to the scenario."
            )
        n = x.shape[0] if isinstance(x, np.ndarray) else len(x)
        return pd.RangeIndex(start, n + start, step=1, dtype="int")

    raise ValueError(f"{type(x)} is not a valid type for scenario data.")
1734
+
1735
+ x0_idx = None
1736
+
1737
+ if use_scenario_index :
1738
+ forecast_index = get_or_create_index (scenario , start )
1739
+ is_datetime = isinstance (forecast_index , pd .DatetimeIndex )
1740
+
1741
+ # If the user provided an index, we want to take it as-is (without removing the start value). Instead,
1742
+ # step one back and use this as the start value.
1743
+ delta = forecast_index .freq if is_datetime else 1
1744
+ x0_idx = forecast_index [0 ] - delta
1710
1745
1711
1746
else :
1712
- if end is not None :
1713
- forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1714
- if periods is not None :
1715
- forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1747
+ # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise
1748
+ # use a range index.
1749
+ is_datetime = isinstance (time_index , pd .DatetimeIndex )
1750
+ forecast_index = None
1751
+
1752
+ if is_datetime :
1753
+ freq = time_index .inferred_freq
1754
+ if isinstance (start , int ):
1755
+ start = time_index [start ]
1756
+ if end is not None :
1757
+ forecast_index = pd .date_range (start , end = end , freq = freq )
1758
+ if periods is not None :
1759
+ forecast_index = pd .date_range (start , periods = periods , freq = freq )
1760
+
1761
+ else :
1762
+ if end is not None :
1763
+ forecast_index = pd .RangeIndex (start , end , step = 1 , dtype = "int" )
1764
+ if periods is not None :
1765
+ forecast_index = pd .RangeIndex (start , start + periods , step = 1 , dtype = "int" )
1766
+
1767
+ if is_datetime :
1768
+ if forecast_index .freq != time_index .freq :
1769
+ raise ValueError (
1770
+ "The frequency of the forecast index must match the frequency on the data used "
1771
+ f"to fit the model. Got { forecast_index .freq } , expected { time_index .freq } "
1772
+ )
1773
+
1774
+ if x0_idx is None :
1775
+ x0_idx , forecast_index = forecast_index [0 ], forecast_index [1 :]
1776
+ if x0_idx in forecast_index :
1777
+ raise ValueError ("x0_idx should not be in the forecast index" )
1778
+ if x0_idx not in time_index :
1779
+ raise ValueError ("start must be in the data index used to fit the model." )
1716
1780
1717
- return forecast_index
1781
+ # The starting value should not be included in the forecast index. It will be used only to define x0 and P0,
1782
+ # and no forecast will be associated with it.
1783
+ return x0_idx , forecast_index
1718
1784
1719
1785
def _finalize_scenario_initialization (
1720
1786
self ,
@@ -1876,7 +1942,7 @@ def forecast(
1876
1942
verbose = verbose ,
1877
1943
)
1878
1944
1879
- forecast_index = self ._build_forecast_index (
1945
+ t0 , forecast_index = self ._build_forecast_index (
1880
1946
time_index = time_index ,
1881
1947
start = start ,
1882
1948
end = end ,
@@ -1892,7 +1958,6 @@ def forecast(
1892
1958
if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
1893
1959
dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
1894
1960
1895
- t0 = forecast_index [0 ]
1896
1961
t0_idx = np .flatnonzero (time_index == t0 )[0 ]
1897
1962
1898
1963
temp_coords ["data_time" ] = time_index
0 commit comments