Skip to content
33 changes: 22 additions & 11 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,28 +2203,39 @@ def forecast(

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
scenario=scenario,
data_dims=["data_time", OBS_STATE_DIM],
)

for name in self.data_names:
if name in scenario.keys():
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break

group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
mu, cov = grouped_outputs[group_idx]

sub_dict = {
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
for data_var in forecast_model.data_vars
}

replacements_diff = np.setdiff1d(
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
)
if replacements_diff.size > 0:
raise ValueError(f"{replacements_diff} data used for fitting not found!")

mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)

x0 = pm.Deterministic(
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
"x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
)
P0 = pm.Deterministic(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)

if scenario is not None:
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
pm.set_data(
scenario | {"data": dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))},
)

_ = LinearGaussianStateSpace(
"forecast",
x0,
Expand Down