@@ -1855,7 +1855,16 @@ def forecast(
1855
1855
)
1856
1856
start = time_index [- 1 ]
1857
1857
1858
- scenario = self ._validate_scenario_data (scenario , verbose = verbose )
1858
+ if not isinstance (scenario , dict ):
1859
+ if len (self .data_names ) > 1 :
1860
+ raise ValueError (
1861
+ "Model needs more than one exogenous data to do forecasting. In this case, you must "
1862
+ "pass a dictionary of scenario data."
1863
+ )
1864
+ [data_name ] = self .data_names
1865
+ scenario = {data_name : scenario }
1866
+
1867
+ scenario : dict = self ._validate_scenario_data (scenario , verbose = verbose )
1859
1868
1860
1869
self ._validate_forecast_args (
1861
1870
time_index = time_index ,
@@ -1917,19 +1926,14 @@ def forecast(
1917
1926
for data_name in self .data_names
1918
1927
}
1919
1928
1920
- subbed_matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1921
- [
1922
- setattr (matrix , "name" , name )
1923
- for name , matrix in zip (MATRIX_NAMES [2 :], subbed_matrices )
1924
- ]
1925
- else :
1926
- subbed_matrices = matrices
1929
+ matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1930
+ [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
1927
1931
1928
1932
_ = LinearGaussianStateSpace (
1929
1933
"forecast" ,
1930
1934
x0 ,
1931
1935
P0 ,
1932
- * subbed_matrices ,
1936
+ * matrices ,
1933
1937
steps = len (forecast_index [:- 1 ]),
1934
1938
dims = dims ,
1935
1939
mode = self ._fit_mode ,
0 commit comments