Skip to content

Commit 6f06a74

Browse files
Dekermanjianjessegrabowski
authored andcommitted
Forecast exogenous vars bug fix (pymc-devs#510)
* fixed bug in statespace forecast method when exogenous variables are present. * updated solution to handle input shapes correctly * simplified fix, renamed mu and cov for transparancy and added a check for the graph replacements * Refactor model builder logic out of `forecast` method * made slight change with _build_forecast_model and created a test case * made change to test_build_forecast_model() to ensure data is replaced with pm.set_data method * added additional checks to test_build_forecast_model * added mock_sample_setup_and_teardown to statespace tests --------- Co-authored-by: jessegrabowski <[email protected]>
1 parent d0db7b5 commit 6f06a74

File tree

2 files changed

+174
-172
lines changed

2 files changed

+174
-172
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(
20472047

20482048
return scenario
20492049

2050+
def _build_forecast_model(
2051+
self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
2052+
):
2053+
filter_time_dim = TIME_DIM
2054+
temp_coords = self._fit_coords.copy()
2055+
2056+
dims = None
2057+
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2058+
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2059+
2060+
t0_idx = np.flatnonzero(time_index == t0)[0]
2061+
2062+
temp_coords["data_time"] = time_index
2063+
temp_coords[TIME_DIM] = forecast_index
2064+
2065+
mu_dims, cov_dims = None, None
2066+
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2067+
mu_dims = ["data_time", ALL_STATE_DIM]
2068+
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2069+
2070+
with pm.Model(coords=temp_coords) as forecast_model:
2071+
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2072+
data_dims=["data_time", OBS_STATE_DIM],
2073+
)
2074+
2075+
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2076+
mu, cov = grouped_outputs[group_idx]
2077+
2078+
sub_dict = {
2079+
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2080+
for data_var in forecast_model.data_vars
2081+
}
2082+
2083+
missing_data_vars = np.setdiff1d(
2084+
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2085+
)
2086+
if missing_data_vars.size > 0:
2087+
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
2088+
2089+
mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)
2090+
2091+
x0 = pm.Deterministic(
2092+
"x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2093+
)
2094+
P0 = pm.Deterministic(
2095+
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2096+
)
2097+
2098+
_ = LinearGaussianStateSpace(
2099+
"forecast",
2100+
x0,
2101+
P0,
2102+
*matrices,
2103+
steps=len(forecast_index),
2104+
dims=dims,
2105+
sequence_names=self.kalman_filter.seq_names,
2106+
k_endog=self.k_endog,
2107+
append_x0=False,
2108+
method=mvn_method,
2109+
)
2110+
2111+
return forecast_model
2112+
20502113
def forecast(
20512114
self,
20522115
idata: InferenceData,
@@ -2139,8 +2202,6 @@ def forecast(
21392202
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
21402203
21412204
"""
2142-
filter_time_dim = TIME_DIM
2143-
21442205
_validate_filter_arg(filter_output)
21452206

21462207
compile_kwargs = kwargs.pop("compile_kwargs", {})
@@ -2185,58 +2246,23 @@ def forecast(
21852246
use_scenario_index=use_scenario_index,
21862247
)
21872248
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
2188-
temp_coords = self._fit_coords.copy()
2189-
2190-
dims = None
2191-
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2192-
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2193-
2194-
t0_idx = np.flatnonzero(time_index == t0)[0]
2195-
2196-
temp_coords["data_time"] = time_index
2197-
temp_coords[TIME_DIM] = forecast_index
2198-
2199-
mu_dims, cov_dims = None, None
2200-
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2201-
mu_dims = ["data_time", ALL_STATE_DIM]
2202-
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2203-
2204-
with pm.Model(coords=temp_coords) as forecast_model:
2205-
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2206-
scenario=scenario,
2207-
data_dims=["data_time", OBS_STATE_DIM],
2208-
)
2209-
2210-
for name in self.data_names:
2211-
if name in scenario.keys():
2212-
pm.set_data(
2213-
{"data": np.zeros((len(forecast_index), self.k_endog))},
2214-
coords={"data_time": np.arange(len(forecast_index))},
2215-
)
2216-
break
22172249

2218-
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2219-
mu, cov = grouped_outputs[group_idx]
2220-
2221-
x0 = pm.Deterministic(
2222-
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2223-
)
2224-
P0 = pm.Deterministic(
2225-
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2226-
)
2250+
forecast_model = self._build_forecast_model(
2251+
time_index=time_index,
2252+
t0=t0,
2253+
forecast_index=forecast_index,
2254+
scenario=scenario,
2255+
filter_output=filter_output,
2256+
mvn_method=mvn_method,
2257+
)
22272258

2228-
_ = LinearGaussianStateSpace(
2229-
"forecast",
2230-
x0,
2231-
P0,
2232-
*matrices,
2233-
steps=len(forecast_index),
2234-
dims=dims,
2235-
sequence_names=self.kalman_filter.seq_names,
2236-
k_endog=self.k_endog,
2237-
append_x0=False,
2238-
method=mvn_method,
2239-
)
2259+
with forecast_model:
2260+
if scenario is not None:
2261+
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262+
pm.set_data(
2263+
scenario | {"data": dummy_obs_data},
2264+
coords={"data_time": np.arange(len(forecast_index))},
2265+
)
22402266

22412267
forecast_model.rvs_to_initial_values = {
22422268
k: None for k in forecast_model.rvs_to_initial_values.keys()

0 commit comments

Comments
 (0)