-
Notifications
You must be signed in to change notification settings - Fork 69
added sample_filter_outputs utility and accompanying simple tests #526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
fd87691
5b064d4
d142a91
9e78bae
46149ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari | |
states, covs = outputs[:4], outputs[4:] | ||
|
||
state_names = [ | ||
"filtered_state", | ||
"predicted_state", | ||
"predicted_observed_state", | ||
"smoothed_state", | ||
"filtered_states", | ||
"predicted_states", | ||
"predicted_observed_states", | ||
"smoothed_states", | ||
] | ||
cov_names = [ | ||
"filtered_covariance", | ||
"predicted_covariance", | ||
"predicted_observed_covariance", | ||
"smoothed_covariance", | ||
"filtered_covariances", | ||
"predicted_covariances", | ||
"predicted_observed_covariances", | ||
"smoothed_covariances", | ||
] | ||
|
||
with mod: | ||
|
@@ -939,7 +939,7 @@ def build_statespace_graph( | |
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] | ||
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) | ||
|
||
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] | ||
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] | ||
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None | ||
|
||
SequenceMvNormal( | ||
|
@@ -1678,6 +1678,83 @@ def sample_statespace_matrices( | |
|
||
return matrix_idata | ||
|
||
def sample_filter_outputs( | ||
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs | ||
): | ||
if isinstance(filter_output_names, str): | ||
filter_output_names = [filter_output_names] | ||
|
||
drop_keys = {"predicted_observed_states", "predicted_observed_covariances"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we shouldn't treat these as special (even though I agree it's silly to ask for them). I'd be confused if I tried to ask for them and it said it's not a valid filter output name. Having everything in one place is convenient, even if it's duplicative. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, yeah I agree with you. I just wasn't sure because in Should I change the names in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, these should be consistent. But where does the name change currently happen between the filter and the idata? Maybe this is an issue for another PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jessegrabowski, I believe this happens in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to make this consistent in this PR I have no objection. I don't have a good sense if you should change FILTER_OUTPUT_DIMS to match the output names, or change the output names to match the FILTER_OUTPUT_DIMS. I'll defer to you if you have a sense of which one is better. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I will do it in this PR because I think it is somewhat related. I think the names should match whatever we put in |
||
all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys} | ||
|
||
if filter_output_names is None: | ||
filter_output_names = list(all_filter_output_dims.keys()) | ||
else: | ||
unknown_filter_output_names = np.setdiff1d( | ||
filter_output_names, list(all_filter_output_dims.keys()) | ||
) | ||
if unknown_filter_output_names.size > 0: | ||
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!") | ||
filter_output_names = [ | ||
x for x in all_filter_output_dims.keys() if x in filter_output_names | ||
] | ||
|
||
compile_kwargs = kwargs.pop("compile_kwargs", {}) | ||
compile_kwargs.setdefault("mode", self.mode) | ||
|
||
with pm.Model(coords=self.coords) as m: | ||
self._build_dummy_graph() | ||
self._insert_random_variables() | ||
|
||
if self.data_names: | ||
for name in self.data_names: | ||
pm.Data(**self._exog_data_info[name]) | ||
|
||
self._insert_data_variables() | ||
|
||
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() | ||
data = self._fit_data | ||
|
||
obs_coords = m.coords.get(OBS_STATE_DIM, None) | ||
|
||
data, nan_mask = register_data_with_pymc( | ||
data, | ||
n_obs=self.ssm.k_endog, | ||
obs_coords=obs_coords, | ||
register_data=True, | ||
) | ||
|
||
filter_outputs = self.kalman_filter.build_graph( | ||
data, | ||
x0, | ||
P0, | ||
c, | ||
d, | ||
T, | ||
Z, | ||
R, | ||
H, | ||
Q, | ||
) | ||
|
||
smoother_outputs = self.kalman_smoother.build_graph( | ||
T, R, Q, filter_outputs[0], filter_outputs[3] | ||
) | ||
|
||
filter_outputs = filter_outputs[:-1] + list(smoother_outputs) | ||
for output in filter_outputs: | ||
if output.name in filter_output_names: | ||
dims = all_filter_output_dims[output.name] | ||
pm.Deterministic(output.name, output, dims=dims) | ||
|
||
with freeze_dims_and_data(m): | ||
return pm.sample_posterior_predictive( | ||
idata if group == "posterior" else idata.prior, | ||
var_names=filter_output_names, | ||
compile_kwargs=compile_kwargs, | ||
**kwargs, | ||
) | ||
|
||
@staticmethod | ||
def _validate_forecast_args( | ||
time_index: pd.RangeIndex | pd.DatetimeIndex, | ||
|
Uh oh!
There was an error while loading. Please reload this page.