-
Notifications
You must be signed in to change notification settings - Fork 73
added sample_filter_outputs utility and accompanying simple tests #526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
fd87691
5b064d4
d142a91
9e78bae
46149ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari | |
| states, covs = outputs[:4], outputs[4:] | ||
|
|
||
| state_names = [ | ||
| "filtered_state", | ||
| "predicted_state", | ||
| "predicted_observed_state", | ||
| "smoothed_state", | ||
| "filtered_states", | ||
| "predicted_states", | ||
| "predicted_observed_states", | ||
| "smoothed_states", | ||
| ] | ||
| cov_names = [ | ||
| "filtered_covariance", | ||
| "predicted_covariance", | ||
| "predicted_observed_covariance", | ||
| "smoothed_covariance", | ||
| "filtered_covariances", | ||
| "predicted_covariances", | ||
| "predicted_observed_covariances", | ||
| "smoothed_covariances", | ||
| ] | ||
|
|
||
| with mod: | ||
|
|
@@ -939,7 +939,7 @@ def build_statespace_graph( | |
| all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] | ||
| self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) | ||
|
|
||
| obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] | ||
| obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] | ||
| obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None | ||
|
|
||
| SequenceMvNormal( | ||
|
|
@@ -1678,6 +1678,83 @@ def sample_statespace_matrices( | |
|
|
||
| return matrix_idata | ||
|
|
||
| def sample_filter_outputs( | ||
| self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs | ||
| ): | ||
| if isinstance(filter_output_names, str): | ||
| filter_output_names = [filter_output_names] | ||
|
|
||
| drop_keys = {"predicted_observed_states", "predicted_observed_covariances"} | ||
|
||
| all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys} | ||
|
|
||
| if filter_output_names is None: | ||
| filter_output_names = list(all_filter_output_dims.keys()) | ||
| else: | ||
| unknown_filter_output_names = np.setdiff1d( | ||
| filter_output_names, list(all_filter_output_dims.keys()) | ||
| ) | ||
| if unknown_filter_output_names.size > 0: | ||
| raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!") | ||
| filter_output_names = [ | ||
| x for x in all_filter_output_dims.keys() if x in filter_output_names | ||
| ] | ||
|
|
||
| compile_kwargs = kwargs.pop("compile_kwargs", {}) | ||
| compile_kwargs.setdefault("mode", self.mode) | ||
|
|
||
| with pm.Model(coords=self.coords) as m: | ||
| self._build_dummy_graph() | ||
| self._insert_random_variables() | ||
|
|
||
| if self.data_names: | ||
| for name in self.data_names: | ||
| pm.Data(**self._exog_data_info[name]) | ||
|
|
||
| self._insert_data_variables() | ||
|
|
||
| x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() | ||
| data = self._fit_data | ||
|
|
||
| obs_coords = m.coords.get(OBS_STATE_DIM, None) | ||
|
|
||
| data, nan_mask = register_data_with_pymc( | ||
| data, | ||
| n_obs=self.ssm.k_endog, | ||
| obs_coords=obs_coords, | ||
| register_data=True, | ||
| ) | ||
|
|
||
| filter_outputs = self.kalman_filter.build_graph( | ||
| data, | ||
| x0, | ||
| P0, | ||
| c, | ||
| d, | ||
| T, | ||
| Z, | ||
| R, | ||
| H, | ||
| Q, | ||
| ) | ||
|
|
||
| smoother_outputs = self.kalman_smoother.build_graph( | ||
| T, R, Q, filter_outputs[0], filter_outputs[3] | ||
| ) | ||
|
|
||
| filter_outputs = filter_outputs[:-1] + list(smoother_outputs) | ||
| for output in filter_outputs: | ||
| if output.name in filter_output_names: | ||
| dims = all_filter_output_dims[output.name] | ||
| pm.Deterministic(output.name, output, dims=dims) | ||
|
|
||
| with freeze_dims_and_data(m): | ||
| return pm.sample_posterior_predictive( | ||
| idata if group == "posterior" else idata.prior, | ||
| var_names=filter_output_names, | ||
| compile_kwargs=compile_kwargs, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _validate_forecast_args( | ||
| time_index: pd.RangeIndex | pd.DatetimeIndex, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.