Skip to content

Commit 2371fec

Browse files
Wrap user scenarios in dictionary
1 parent 07a7f33 commit 2371fec

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,7 +1867,16 @@ def forecast(
18671867
)
18681868
start = time_index[-1]
18691869

1870-
scenario = self._validate_scenario_data(scenario, verbose=verbose)
1870+
if not isinstance(scenario, dict):
1871+
if len(self.data_names) > 1:
1872+
raise ValueError(
1873+
"Model needs more than one exogenous data to do forecasting. In this case, you must "
1874+
"pass a dictionary of scenario data."
1875+
)
1876+
[data_name] = self.data_names
1877+
scenario = {data_name: scenario}
1878+
1879+
scenario: dict = self._validate_scenario_data(scenario, verbose=verbose)
18711880

18721881
self._validate_forecast_args(
18731882
time_index=time_index,
@@ -1929,19 +1938,14 @@ def forecast(
19291938
for data_name in self.data_names
19301939
}
19311940

1932-
subbed_matrices = graph_replace(matrices, replace=sub_dict, strict=True)
1933-
[
1934-
setattr(matrix, "name", name)
1935-
for name, matrix in zip(MATRIX_NAMES[2:], subbed_matrices)
1936-
]
1937-
else:
1938-
subbed_matrices = matrices
1941+
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
1942+
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
19391943

19401944
_ = LinearGaussianStateSpace(
19411945
"forecast",
19421946
x0,
19431947
P0,
1944-
*subbed_matrices,
1948+
*matrices,
19451949
steps=len(forecast_index[:-1]),
19461950
dims=dims,
19471951
mode=self._fit_mode,

0 commit comments

Comments
 (0)