Skip to content

Commit 5b064d4

Browse files
committed
1. removed modelcontext call that is not needed
2. Added handle for when filter_output param is passed in as a str 3. removed case statement in favor of dictionary mapping that already exists in conf.py
1 parent fd87691 commit 5b064d4

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,11 +1681,13 @@ def sample_statespace_matrices(
16811681
def sample_filter_outputs(
16821682
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
16831683
):
1684+
if isinstance(filter_output_names, str):
1685+
filter_output_names = [filter_output_names]
1686+
16841687
compile_kwargs = kwargs.pop("compile_kwargs", {})
16851688
compile_kwargs.setdefault("mode", self.mode)
16861689

16871690
with pm.Model(coords=self.coords) as m:
1688-
pm_mod = modelcontext(None)
16891691
self._build_dummy_graph()
16901692
self._insert_random_variables()
16911693

@@ -1698,7 +1700,7 @@ def sample_filter_outputs(
16981700
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
16991701
data = self._fit_data
17001702

1701-
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
1703+
obs_coords = m.coords.get(OBS_STATE_DIM, None)
17021704

17031705
data, nan_mask = register_data_with_pymc(
17041706
data,
@@ -1724,7 +1726,16 @@ def sample_filter_outputs(
17241726
T, R, Q, filter_outputs[0], filter_outputs[3]
17251727
)
17261728

1729+
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
1730+
filter_output_dims_mapping = {}
1731+
for k in FILTER_OUTPUT_DIMS.keys():
1732+
filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
1733+
17271734
all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1735+
# This excludes observed states and observed covariances from the filter outputs
1736+
all_filter_outputs = [
1737+
output for output in all_filter_outputs if output.name in filter_output_dims_mapping
1738+
]
17281739

17291740
if filter_output_names is None:
17301741
filter_output_names = all_filter_outputs
@@ -1741,16 +1752,7 @@ def sample_filter_outputs(
17411752
]
17421753

17431754
for output in filter_output_names:
1744-
match output.name:
1745-
case "filtered_states" | "predicted_states" | "smoothed_states":
1746-
dims = [TIME_DIM, "state"]
1747-
case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
1748-
dims = [TIME_DIM, "state", "state_aux"]
1749-
case "observed_states":
1750-
dims = [TIME_DIM, "observed_state"]
1751-
case "observed_covariances":
1752-
dims = [TIME_DIM, "observed_state", "observed_state_aux"]
1753-
1755+
dims = filter_output_dims_mapping[output.name]
17541756
pm.Deterministic(output.name, output, dims=dims)
17551757

17561758
frozen_model = freeze_dims_and_data(m)

0 commit comments

Comments
 (0)