@@ -1855,7 +1855,16 @@ def forecast(
18551855 )
18561856 start = time_index [- 1 ]
18571857
1858- scenario = self ._validate_scenario_data (scenario , verbose = verbose )
1858+ if not isinstance (scenario , dict ):
1859+ if len (self .data_names ) > 1 :
1860+ raise ValueError (
1861+ "Model needs more than one exogenous data to do forecasting. In this case, you must "
1862+ "pass a dictionary of scenario data."
1863+ )
1864+ [data_name ] = self .data_names
1865+ scenario = {data_name : scenario }
1866+
1867+ scenario : dict = self ._validate_scenario_data (scenario , verbose = verbose )
18591868
18601869 self ._validate_forecast_args (
18611870 time_index = time_index ,
@@ -1917,19 +1926,14 @@ def forecast(
19171926 for data_name in self .data_names
19181927 }
19191928
1920- subbed_matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1921- [
1922- setattr (matrix , "name" , name )
1923- for name , matrix in zip (MATRIX_NAMES [2 :], subbed_matrices )
1924- ]
1925- else :
1926- subbed_matrices = matrices
1929+ matrices = graph_replace (matrices , replace = sub_dict , strict = True )
1930+ [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
19271931
19281932 _ = LinearGaussianStateSpace (
19291933 "forecast" ,
19301934 x0 ,
19311935 P0 ,
1932- * subbed_matrices ,
1936+ * matrices ,
19331937 steps = len (forecast_index [:- 1 ]),
19341938 dims = dims ,
19351939 mode = self ._fit_mode ,
0 commit comments