From fd87691d239346c3a51d7af1b61e21d0edaabb16 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Tue, 17 Jun 2025 10:18:07 -0600 Subject: [PATCH 1/5] added sample_filter_outputs utility and accompanying simple tests Rebased from upstream --- pymc_extras/statespace/core/statespace.py | 85 +++++++++++++++++++++++ tests/statespace/core/test_statespace.py | 29 ++++++++ 2 files changed, 114 insertions(+) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 96e1e9b52..76ec71e40 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -1678,6 +1678,91 @@ def sample_statespace_matrices( return matrix_idata + def sample_filter_outputs( + self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs + ): + compile_kwargs = kwargs.pop("compile_kwargs", {}) + compile_kwargs.setdefault("mode", self.mode) + + with pm.Model(coords=self.coords) as m: + pm_mod = modelcontext(None) + self._build_dummy_graph() + self._insert_random_variables() + + if self.data_names: + for name in self.data_names: + pm.Data(**self._exog_data_info[name]) + + self._insert_data_variables() + + x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() + data = self._fit_data + + obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) + + data, nan_mask = register_data_with_pymc( + data, + n_obs=self.ssm.k_endog, + obs_coords=obs_coords, + register_data=True, + ) + + filter_outputs = self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ) + + smoother_outputs = self.kalman_smoother.build_graph( + T, R, Q, filter_outputs[0], filter_outputs[3] + ) + + all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs) + + if filter_output_names is None: + filter_output_names = all_filter_outputs + else: + unknown_filter_output_names = np.setdiff1d( + filter_output_names, [x.name for x in all_filter_outputs] + ) + if unknown_filter_output_names.size > 0: + raise ValueError( + f"{unknown_filter_output_names} not a valid filter output name!" + ) + filter_output_names = [ + x for x in all_filter_outputs if x.name in filter_output_names + ] + + for output in filter_output_names: + match output.name: + case "filtered_states" | "predicted_states" | "smoothed_states": + dims = [TIME_DIM, "state"] + case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances": + dims = [TIME_DIM, "state", "state_aux"] + case "observed_states": + dims = [TIME_DIM, "observed_state"] + case "observed_covariances": + dims = [TIME_DIM, "observed_state", "observed_state_aux"] + + pm.Deterministic(output.name, output, dims=dims) + + frozen_model = freeze_dims_and_data(m) + with frozen_model: + idata_filter = pm.sample_posterior_predictive( + idata if group == "posterior" else idata.prior, + var_names=[x.name for x in frozen_model.deterministics], + compile_kwargs=compile_kwargs, + **kwargs, + ) + return idata_filter + @staticmethod def _validate_forecast_args( time_index: pd.RangeIndex | pd.DatetimeIndex, diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 63b076bd9..3e1e80c0e 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -1,3 +1,5 @@ +import re + from collections.abc import Sequence from functools import partial @@ -1244,3 +1246,30 @@ def test_param_dims_coords(ss_mod_multi_component): assert i == len( ss_mod_multi_component.coords[s] ), f"Mismatch between shape {i} and dimension {s}" + + +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") +@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") +def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog): + # Simple tests + idata_filter_prior = exog_ss_mod.sample_filter_outputs( + idata_exog, filter_output_names=None, group="prior" + ) + + specific_outputs = ["filtered_states", "filtered_covariances"] + idata_filter_specific = exog_ss_mod.sample_filter_outputs( + idata_exog, filter_output_names=specific_outputs + ) + missing_outputs = np.setdiff1d( + specific_outputs, [x for x in idata_filter_specific.posterior_predictive.data_vars] + ) + + assert missing_outputs.size == 0 + + msg = "['filter_covariances' 'filter_states'] not a valid filter output name!" + incorrect_outputs = ["filter_states", "filter_covariances"] + with pytest.raises(ValueError, match=re.escape(msg)): + exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=incorrect_outputs) From 5b064d4be67e5baa037fd4268fdd3fef84e5cc28 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Mon, 21 Jul 2025 15:39:38 -0600 Subject: [PATCH 2/5] 1. removed modelcontext call that is not needed 2. Added handle for when filter_output param is passed in as a str 3. removed case statement in favor of dictionary mapping that already exists in conf.py --- pymc_extras/statespace/core/statespace.py | 26 ++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 76ec71e40..55dee3b1d 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -1681,11 +1681,13 @@ def sample_statespace_matrices( def sample_filter_outputs( self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs ): + if isinstance(filter_output_names, str): + filter_output_names = [filter_output_names] + compile_kwargs = kwargs.pop("compile_kwargs", {}) compile_kwargs.setdefault("mode", self.mode) with pm.Model(coords=self.coords) as m: - pm_mod = modelcontext(None) self._build_dummy_graph() self._insert_random_variables() @@ -1698,7 +1700,7 @@ def sample_filter_outputs( x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() data = self._fit_data - obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) + obs_coords = m.coords.get(OBS_STATE_DIM, None) data, nan_mask = register_data_with_pymc( data, @@ -1724,7 +1726,16 @@ def sample_filter_outputs( T, R, Q, filter_outputs[0], filter_outputs[3] ) + # Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph() + filter_output_dims_mapping = {} + for k in FILTER_OUTPUT_DIMS.keys(): + filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k] + all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs) + # This excludes observed states and observed covariances from the filter outputs + all_filter_outputs = [ + output for output in all_filter_outputs if output.name in filter_output_dims_mapping + ] if filter_output_names is None: filter_output_names = all_filter_outputs @@ -1741,16 +1752,7 @@ def sample_filter_outputs( ] for output in filter_output_names: - match output.name: - case "filtered_states" | "predicted_states" | "smoothed_states": - dims = [TIME_DIM, "state"] - case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances": - dims = [TIME_DIM, "state", "state_aux"] - case "observed_states": - dims = [TIME_DIM, "observed_state"] - case "observed_covariances": - dims = [TIME_DIM, "observed_state", "observed_state_aux"] - + dims = filter_output_dims_mapping[output.name] pm.Deterministic(output.name, output, dims=dims) frozen_model = freeze_dims_and_data(m) From d142a911dede372f780566a309a0ba08be680af8 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Tue, 22 Jul 2025 16:20:12 -0600 Subject: [PATCH 3/5] updated plurality for some of the constants in constants.py --- pymc_extras/statespace/core/statespace.py | 28 +++++++++---------- pymc_extras/statespace/utils/constants.py | 28 +++++++++---------- tests/statespace/core/test_statespace.py | 14 +++++----- tests/statespace/models/test_SARIMAX.py | 2 +- tests/statespace/models/test_VARMAX.py | 2 +- .../statespace/utils/test_coord_assignment.py | 2 +- 6 files changed, 38 insertions(+), 38 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 55dee3b1d..bd1cbd31d 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari states, covs = outputs[:4], outputs[4:] state_names = [ - "filtered_state", - "predicted_state", - "predicted_observed_state", - "smoothed_state", + "filtered_states", + "predicted_states", + "predicted_observed_states", + "smoothed_states", ] cov_names = [ - "filtered_covariance", - "predicted_covariance", - "predicted_observed_covariance", - "smoothed_covariance", + "filtered_covariances", + "predicted_covariances", + "predicted_observed_covariances", + "smoothed_covariances", ] with mod: @@ -939,7 +939,7 @@ def build_statespace_graph( all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) - obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] + obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None SequenceMvNormal( @@ -1727,14 +1727,14 @@ def sample_filter_outputs( ) # Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph() - filter_output_dims_mapping = {} - for k in FILTER_OUTPUT_DIMS.keys(): - filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k] + # filter_output_dims_mapping = {} + # for k in FILTER_OUTPUT_DIMS.keys(): + # filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k] all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs) # This excludes observed states and observed covariances from the filter outputs all_filter_outputs = [ - output for output in all_filter_outputs if output.name in filter_output_dims_mapping + output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS ] if filter_output_names is None: @@ -1752,7 +1752,7 @@ def sample_filter_outputs( ] for output in filter_output_names: - dims = filter_output_dims_mapping[output.name] + dims = FILTER_OUTPUT_DIMS[output.name] pm.Deterministic(output.name, output, dims=dims) frozen_model = freeze_dims_and_data(m) diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index 3e9a3958b..a4ec80fb7 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -38,14 +38,14 @@ LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES)) FILTER_OUTPUT_NAMES = [ - "filtered_state", - "predicted_state", - "filtered_covariance", - "predicted_covariance", + "filtered_states", + "predicted_states", + "filtered_covariances", + "predicted_covariances", ] -SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"] -OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"] +SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"] +OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"] MATRIX_DIMS = { "x0": (ALL_STATE_DIM,), @@ -60,14 +60,14 @@ } FILTER_OUTPUT_DIMS = { - "filtered_state": (TIME_DIM, ALL_STATE_DIM), - "smoothed_state": (TIME_DIM, ALL_STATE_DIM), - "predicted_state": (TIME_DIM, ALL_STATE_DIM), - "filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "predicted_observed_state": (TIME_DIM, OBS_STATE_DIM), - "predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM), + "filtered_states": (TIME_DIM, ALL_STATE_DIM), + "smoothed_states": (TIME_DIM, ALL_STATE_DIM), + "predicted_states": (TIME_DIM, ALL_STATE_DIM), + "filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "predicted_observed_states": (TIME_DIM, OBS_STATE_DIM), + "predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM), } POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"] diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 3e1e80c0e..06aec484e 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -487,16 +487,16 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill(): def test_build_statespace_graph(pymc_mod): for name in [ - "filtered_state", - "predicted_state", - "predicted_covariance", - "filtered_covariance", + "filtered_states", + "predicted_states", + "predicted_covariances", + "filtered_covariances", ]: assert name in [x.name for x in pymc_mod.deterministics] def test_build_smoother_graph(ss_mod, pymc_mod): - names = ["smoothed_state", "smoothed_covariance"] + names = ["smoothed_states", "smoothed_covariances"] for name in names: assert name in [x.name for x in pymc_mod.deterministics] @@ -1193,11 +1193,11 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_ # Check that the frozen states and covariances correctly match the sliced index np.testing.assert_allclose( - idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, + idata_exog.posterior["predicted_covariances"].sel(time=t0).mean(("chain", "draw")).values, idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, ) np.testing.assert_allclose( - idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, + idata_exog.posterior["predicted_states"].sel(time=t0).mean(("chain", "draw")).values, idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, ) diff --git a/tests/statespace/models/test_SARIMAX.py b/tests/statespace/models/test_SARIMAX.py index 693b367d8..d04303b2f 100644 --- a/tests/statespace/models/test_SARIMAX.py +++ b/tests/statespace/models/test_SARIMAX.py @@ -321,7 +321,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng): @pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"]) def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): - rv = pymc_mod[f"{filter_output}_covariance"] + rv = pymc_mod[f"{filter_output}_covariances"] cov_mats = pm.draw(rv, 100, random_seed=rng) w, v = np.linalg.eig(cov_mats) assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}") diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index 31f112772..fbd0cfc04 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -156,7 +156,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng): @pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"]) def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): - rv = pymc_mod[f"{filter_output}_covariance"] + rv = pymc_mod[f"{filter_output}_covariances"] cov_mats = pm.draw(rv, 100, random_seed=rng) w, v = np.linalg.eig(cov_mats) assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}") diff --git a/tests/statespace/utils/test_coord_assignment.py b/tests/statespace/utils/test_coord_assignment.py index 22f98e921..4f27d4250 100644 --- a/tests/statespace/utils/test_coord_assignment.py +++ b/tests/statespace/utils/test_coord_assignment.py @@ -93,7 +93,7 @@ def test_filter_output_coord_assignment(f, warning, create_model): with warning: pymc_model = create_model(f) - for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]: + for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_states"]: assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output] From 9e78bae46afe7853375282ca4f0e85131362eef3 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Fri, 25 Jul 2025 07:20:20 -0600 Subject: [PATCH 4/5] cleaned up commented code, moved internal checks to the top, reduced intermediate variables --- pymc_extras/statespace/core/statespace.py | 56 ++++++++++------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index bd1cbd31d..c1ebe683c 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -1684,6 +1684,21 @@ def sample_filter_outputs( if isinstance(filter_output_names, str): filter_output_names = [filter_output_names] + drop_keys = {"predicted_observed_states", "predicted_observed_covariances"} + all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys} + + if filter_output_names is None: + filter_output_names = list(all_filter_output_dims.keys()) + else: + unknown_filter_output_names = np.setdiff1d( + filter_output_names, list(all_filter_output_dims.keys()) + ) + if unknown_filter_output_names.size > 0: + raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!") + filter_output_names = [ + x for x in all_filter_output_dims.keys() if x in filter_output_names + ] + compile_kwargs = kwargs.pop("compile_kwargs", {}) compile_kwargs.setdefault("mode", self.mode) @@ -1726,44 +1741,19 @@ def sample_filter_outputs( T, R, Q, filter_outputs[0], filter_outputs[3] ) - # Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph() - # filter_output_dims_mapping = {} - # for k in FILTER_OUTPUT_DIMS.keys(): - # filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k] - - all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs) - # This excludes observed states and observed covariances from the filter outputs - all_filter_outputs = [ - output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS - ] + filter_outputs = filter_outputs[:-1] + list(smoother_outputs) + for output in filter_outputs: + if output.name in filter_output_names: + dims = all_filter_output_dims[output.name] + pm.Deterministic(output.name, output, dims=dims) - if filter_output_names is None: - filter_output_names = all_filter_outputs - else: - unknown_filter_output_names = np.setdiff1d( - filter_output_names, [x.name for x in all_filter_outputs] - ) - if unknown_filter_output_names.size > 0: - raise ValueError( - f"{unknown_filter_output_names} not a valid filter output name!" - ) - filter_output_names = [ - x for x in all_filter_outputs if x.name in filter_output_names - ] - - for output in filter_output_names: - dims = FILTER_OUTPUT_DIMS[output.name] - pm.Deterministic(output.name, output, dims=dims) - - frozen_model = freeze_dims_and_data(m) - with frozen_model: - idata_filter = pm.sample_posterior_predictive( + with freeze_dims_and_data(m): + return pm.sample_posterior_predictive( idata if group == "posterior" else idata.prior, - var_names=[x.name for x in frozen_model.deterministics], + var_names=filter_output_names, compile_kwargs=compile_kwargs, **kwargs, ) - return idata_filter @staticmethod def _validate_forecast_args( From 46149ac514e96e5da1f69b39682b1cd6d17c4956 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 26 Jul 2025 05:24:42 -0600 Subject: [PATCH 5/5] updated kalman filter outputs to use names defined in constants.py, updated sample_filter_outputs to allow sampling any filter outputs defined in constants.py --- pymc_extras/statespace/core/statespace.py | 13 +++------ .../statespace/filters/kalman_filter.py | 27 +++++++++++-------- pymc_extras/statespace/utils/constants.py | 2 ++ 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index c1ebe683c..0b0f3cd43 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -1684,20 +1684,15 @@ def sample_filter_outputs( if isinstance(filter_output_names, str): filter_output_names = [filter_output_names] - drop_keys = {"predicted_observed_states", "predicted_observed_covariances"} - all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys} - if filter_output_names is None: - filter_output_names = list(all_filter_output_dims.keys()) + filter_output_names = list(FILTER_OUTPUT_DIMS.keys()) else: unknown_filter_output_names = np.setdiff1d( - filter_output_names, list(all_filter_output_dims.keys()) + filter_output_names, list(FILTER_OUTPUT_DIMS.keys()) ) if unknown_filter_output_names.size > 0: raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!") - filter_output_names = [ - x for x in all_filter_output_dims.keys() if x in filter_output_names - ] + filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names] compile_kwargs = kwargs.pop("compile_kwargs", {}) compile_kwargs.setdefault("mode", self.mode) @@ -1744,7 +1739,7 @@ def sample_filter_outputs( filter_outputs = filter_outputs[:-1] + list(smoother_outputs) for output in filter_outputs: if output.name in filter_output_names: - dims = all_filter_output_dims[output.name] + dims = FILTER_OUTPUT_DIMS[output.name] pm.Deterministic(output.name, output, dims=dims) with freeze_dims_and_data(m): diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index e4cbc8bed..cf9ac4aa3 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -15,10 +15,15 @@ split_vars_into_seq_and_nonseq, stabilize, ) -from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL +from pymc_extras.statespace.utils.constants import ( + FILTER_OUTPUT_NAMES, + JITTER_DEFAULT, + MATRIX_NAMES, + MISSING_FILL, +) MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) -PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] +PARAM_NAMES = MATRIX_NAMES[2:] assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " @@ -119,7 +124,7 @@ def unpack_args(self, args) -> tuple: # There are always two outputs_info wedged between the seqs and non_seqs seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :] return_ordered = [] - for name in ["c", "d", "T", "Z", "R", "H", "Q"]: + for name in PARAM_NAMES: if name in self.seq_names: idx = self.seq_names.index(name) return_ordered.append(seqs[idx]) @@ -253,28 +258,28 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: ) filtered_states = pt.specify_shape(filtered_states, (n, self.n_states)) - filtered_states.name = "filtered_states" + filtered_states.name = FILTER_OUTPUT_NAMES[0] predicted_states = pt.specify_shape(predicted_states, (n, self.n_states)) - predicted_states.name = "predicted_states" - - observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) - observed_states.name = "observed_states" + predicted_states.name = FILTER_OUTPUT_NAMES[1] filtered_covariances = pt.specify_shape( filtered_covariances, (n, self.n_states, self.n_states) ) - filtered_covariances.name = "filtered_covariances" + filtered_covariances.name = FILTER_OUTPUT_NAMES[2] predicted_covariances = pt.specify_shape( predicted_covariances, (n, self.n_states, self.n_states) ) - predicted_covariances.name = "predicted_covariances" + predicted_covariances.name = FILTER_OUTPUT_NAMES[3] + + observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) + observed_states.name = FILTER_OUTPUT_NAMES[4] observed_covariances = pt.specify_shape( observed_covariances, (n, self.n_endog, self.n_endog) ) - observed_covariances.name = "observed_covariances" + observed_covariances.name = FILTER_OUTPUT_NAMES[5] loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,)) loglike_obs.name = "loglike_obs" diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index a4ec80fb7..5df2ba4a5 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -42,6 +42,8 @@ "predicted_states", "filtered_covariances", "predicted_covariances", + "predicted_observed_states", + "predicted_observed_covariances", ] SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]