Skip to content

Commit eaf285c

Browse files
Wrap user scenarios in dictionary
1 parent f92782c commit eaf285c

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,16 @@ def forecast(
18551855
)
18561856
start = time_index[-1]
18571857

1858-
scenario = self._validate_scenario_data(scenario, verbose=verbose)
1858+
if not isinstance(scenario, dict):
1859+
if len(self.data_names) > 1:
1860+
raise ValueError(
1861+
"Model needs more than one exogenous data to do forecasting. In this case, you must "
1862+
"pass a dictionary of scenario data."
1863+
)
1864+
[data_name] = self.data_names
1865+
scenario = {data_name: scenario}
1866+
1867+
scenario: dict = self._validate_scenario_data(scenario, verbose=verbose)
18591868

18601869
self._validate_forecast_args(
18611870
time_index=time_index,
@@ -1917,19 +1926,14 @@ def forecast(
19171926
for data_name in self.data_names
19181927
}
19191928

1920-
subbed_matrices = graph_replace(matrices, replace=sub_dict, strict=True)
1921-
[
1922-
setattr(matrix, "name", name)
1923-
for name, matrix in zip(MATRIX_NAMES[2:], subbed_matrices)
1924-
]
1925-
else:
1926-
subbed_matrices = matrices
1929+
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
1930+
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
19271931

19281932
_ = LinearGaussianStateSpace(
19291933
"forecast",
19301934
x0,
19311935
P0,
1932-
*subbed_matrices,
1936+
*matrices,
19331937
steps=len(forecast_index[:-1]),
19341938
dims=dims,
19351939
mode=self._fit_mode,

0 commit comments

Comments
 (0)