Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,21 @@ def sample_filter_outputs(
if isinstance(filter_output_names, str):
filter_output_names = [filter_output_names]

drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shouldn't treat these as special (even though I agree it's silly to ask for them). I'd be confused if I tried to ask for them and it said it's not a valid filter output name.

Having everything in one place is convenient, even if it's duplicative.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, yeah I agree with you.

I just wasn't sure because in constants.py FILTER_OUTPUT_DIMS has predicted_observed_states and predicted_observed_covariances but the output from kalman_filter.build_graph() doesn't have predicted_observed_states and predicted_observed_covariances it seems like these are named observed_states and observed_covariances.

Should I change the names in constants.py to match the returned names from kalman_filter.build_graph()?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these should be consistent. But where does the name change currently happen between the filter and the idata? Maybe this is an issue for another PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski, I believe this happens in _postprocess_scan_results() in kalman_filter.py. It looks like the names of the filter outputs are hardcoded in there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to make this consistent in this PR I have no objection. I don't have a good sense if you should change FILTER_OUTPUT_DIMS to match the output names, or change the output names to match the FILTER_OUTPUT_DIMS. I'll defer to you if you have a sense of which one is better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I will do it in this PR because I think it is somewhat related. I think the names should match whatever we put in FILTER_OUTPUT_DIMS

all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys}

if filter_output_names is None:
filter_output_names = list(all_filter_output_dims.keys())
else:
unknown_filter_output_names = np.setdiff1d(
filter_output_names, list(all_filter_output_dims.keys())
)
if unknown_filter_output_names.size > 0:
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
filter_output_names = [
x for x in all_filter_output_dims.keys() if x in filter_output_names
]

compile_kwargs = kwargs.pop("compile_kwargs", {})
compile_kwargs.setdefault("mode", self.mode)

Expand Down Expand Up @@ -1726,44 +1741,19 @@ def sample_filter_outputs(
T, R, Q, filter_outputs[0], filter_outputs[3]
)

# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
# filter_output_dims_mapping = {}
# for k in FILTER_OUTPUT_DIMS.keys():
# filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]

all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
# This excludes observed states and observed covariances from the filter outputs
all_filter_outputs = [
output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS
]
filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
for output in filter_outputs:
if output.name in filter_output_names:
dims = all_filter_output_dims[output.name]
pm.Deterministic(output.name, output, dims=dims)

if filter_output_names is None:
filter_output_names = all_filter_outputs
else:
unknown_filter_output_names = np.setdiff1d(
filter_output_names, [x.name for x in all_filter_outputs]
)
if unknown_filter_output_names.size > 0:
raise ValueError(
f"{unknown_filter_output_names} not a valid filter output name!"
)
filter_output_names = [
x for x in all_filter_outputs if x.name in filter_output_names
]

for output in filter_output_names:
dims = FILTER_OUTPUT_DIMS[output.name]
pm.Deterministic(output.name, output, dims=dims)

frozen_model = freeze_dims_and_data(m)
with frozen_model:
idata_filter = pm.sample_posterior_predictive(
with freeze_dims_and_data(m):
return pm.sample_posterior_predictive(
idata if group == "posterior" else idata.prior,
var_names=[x.name for x in frozen_model.deterministics],
var_names=filter_output_names,
compile_kwargs=compile_kwargs,
**kwargs,
)
return idata_filter

@staticmethod
def _validate_forecast_args(
Expand Down