@@ -1867,7 +1867,16 @@ def forecast(
18671867 )
18681868 start = time_index [- 1 ]
18691869
1870- scenario = self ._validate_scenario_data (scenario , verbose = verbose )
1870+ if not isinstance (scenario , dict ):
1871+ if len (self .data_names ) > 1 :
1872+ raise ValueError (
1873+ "Model needs more than one exogenous data to do forecasting. In this case, you must "
1874+ "pass a dictionary of scenario data."
1875+ )
1876+ [data_name ] = self .data_names
1877+ scenario = {data_name : scenario }
1878+
1879+ scenario : dict = self ._validate_scenario_data (scenario , verbose = verbose )
18711880
18721881 self ._validate_forecast_args (
18731882 time_index = time_index ,
@@ -1929,19 +1938,14 @@ def forecast(
19291938 for data_name in self .data_names
19301939 }
19311940
1932- subbed_matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1933- [
1934- setattr (matrix , "name" , name )
1935- for name , matrix in zip (MATRIX_NAMES [2 :], subbed_matrices )
1936- ]
1937- else :
1938- subbed_matrices = matrices
1941+ matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1942+ [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
19391943
19401944 _ = LinearGaussianStateSpace (
19411945 "forecast" ,
19421946 x0 ,
19431947 P0 ,
1944- * subbed_matrices ,
1948+ * matrices ,
19451949 steps = len (forecast_index [:- 1 ]),
19461950 dims = dims ,
19471951 mode = self ._fit_mode ,
0 commit comments