Skip to content
46 changes: 37 additions & 9 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,28 +2203,56 @@ def forecast(

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
scenario=scenario,
data_dims=["data_time", OBS_STATE_DIM],
)

for name in self.data_names:
if name in scenario.keys():
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break

group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
mu, cov = grouped_outputs[group_idx]

if scenario is not None:
sub_dict = {
forecast_model[data_name]: pt.as_tensor_variable(
x=np.atleast_2d(self._exog_data_info[data_name]["value"].T).T,
name=data_name,
)
for data_name in self.data_names
}

# Will this always be named "data"?
sub_dict[forecast_model["data"]] = pt.as_tensor_variable(
np.atleast_2d(self._fit_data.T).T, name="data"
)
else:
# same here will it always be named data?
sub_dict = {
forecast_model["data"]: pt.as_tensor_variable(
np.atleast_2d(self._fit_data.T).T, name="data"
)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify all this by using the model data_vars property. PyMC models store all the symbolic variables created with pm.Data here. They are pytensor SharedVariables, so the data lives inside them. You can get SharedVariable data with .get_value(). That will also ensure the shapes are correct.

So you can do something like:

sub_dict = {
data_var: pt.as_tensor_variable(data_var.get_value(), name="data") for data_var in forecast_model.data_vars
} 

We want free absolutely all data, so this should always do what we want. When there's no self.data_names, it should still find the main data (even if we change the name later). When there is, it will freeze that as well.

If you want, you can add a sanity check after that makes sure the names of the variables in the keys of sub_dict are in self.data_names + ['data']


mu, cov = graph_replace([mu, cov], replace=sub_dict, strict=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe give new names, like frozen_mu, or mu_given_training_data? Makes it clear what we're doing here.


x0 = pm.Deterministic(
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
)
P0 = pm.Deterministic(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)

if scenario is not None:
for name in self.data_names:
if name in scenario.keys():
pm.set_data({name: scenario[name]})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pm.set_data can take a dictionary of multiple things, so no need to loop here. Just directly do pm.set_data(scenario). I think the scenario has already been validated by this point (by self._validate_scenario_data), so there's no need to make sure the keys are in self.data_names

I'm pretty sure you can do something like:

if scenario is not None:
    pm.set_data(scenario | {'data': dummy_obs_data},
                coords={"data_time": np.arange(len(forecast_index))})


for name in self.data_names:
if name in scenario.keys():
# same here will it always be named data?
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break

_ = LinearGaussianStateSpace(
"forecast",
x0,
Expand Down