diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 96e1e9b52..0b0f3cd43 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari states, covs = outputs[:4], outputs[4:] state_names = [ - "filtered_state", - "predicted_state", - "predicted_observed_state", - "smoothed_state", + "filtered_states", + "predicted_states", + "predicted_observed_states", + "smoothed_states", ] cov_names = [ - "filtered_covariance", - "predicted_covariance", - "predicted_observed_covariance", - "smoothed_covariance", + "filtered_covariances", + "predicted_covariances", + "predicted_observed_covariances", + "smoothed_covariances", ] with mod: @@ -939,7 +939,7 @@ def build_statespace_graph( all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) - obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] + obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None SequenceMvNormal( @@ -1678,6 +1678,78 @@ def sample_statespace_matrices( return matrix_idata + def sample_filter_outputs( + self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs + ): + if isinstance(filter_output_names, str): + filter_output_names = [filter_output_names] + + if filter_output_names is None: + filter_output_names = list(FILTER_OUTPUT_DIMS.keys()) + else: + unknown_filter_output_names = np.setdiff1d( + filter_output_names, list(FILTER_OUTPUT_DIMS.keys()) + ) + if unknown_filter_output_names.size > 0: + raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!") + filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names] + + compile_kwargs = kwargs.pop("compile_kwargs", {}) + compile_kwargs.setdefault("mode", self.mode) + + with pm.Model(coords=self.coords) as m: + self._build_dummy_graph() + self._insert_random_variables() + + if self.data_names: + for name in self.data_names: + pm.Data(**self._exog_data_info[name]) + + self._insert_data_variables() + + x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() + data = self._fit_data + + obs_coords = m.coords.get(OBS_STATE_DIM, None) + + data, nan_mask = register_data_with_pymc( + data, + n_obs=self.ssm.k_endog, + obs_coords=obs_coords, + register_data=True, + ) + + filter_outputs = self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ) + + smoother_outputs = self.kalman_smoother.build_graph( + T, R, Q, filter_outputs[0], filter_outputs[3] + ) + + filter_outputs = filter_outputs[:-1] + list(smoother_outputs) + for output in filter_outputs: + if output.name in filter_output_names: + dims = FILTER_OUTPUT_DIMS[output.name] + pm.Deterministic(output.name, output, dims=dims) + + with freeze_dims_and_data(m): + return pm.sample_posterior_predictive( + idata if group == "posterior" else idata.prior, + var_names=filter_output_names, + compile_kwargs=compile_kwargs, + **kwargs, + ) + @staticmethod def _validate_forecast_args( time_index: pd.RangeIndex | pd.DatetimeIndex, diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index e4cbc8bed..cf9ac4aa3 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -15,10 +15,15 @@ split_vars_into_seq_and_nonseq, stabilize, ) -from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL +from pymc_extras.statespace.utils.constants import ( + FILTER_OUTPUT_NAMES, + JITTER_DEFAULT, + MATRIX_NAMES, + MISSING_FILL, +) MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) -PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] +PARAM_NAMES = MATRIX_NAMES[2:] assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " @@ -119,7 +124,7 @@ def unpack_args(self, args) -> tuple: # There are always two outputs_info wedged between the seqs and non_seqs seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :] return_ordered = [] - for name in ["c", "d", "T", "Z", "R", "H", "Q"]: + for name in PARAM_NAMES: if name in self.seq_names: idx = self.seq_names.index(name) return_ordered.append(seqs[idx]) @@ -253,28 +258,28 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: ) filtered_states = pt.specify_shape(filtered_states, (n, self.n_states)) - filtered_states.name = "filtered_states" + filtered_states.name = FILTER_OUTPUT_NAMES[0] predicted_states = pt.specify_shape(predicted_states, (n, self.n_states)) - predicted_states.name = "predicted_states" - - observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) - observed_states.name = "observed_states" + predicted_states.name = FILTER_OUTPUT_NAMES[1] filtered_covariances = pt.specify_shape( filtered_covariances, (n, self.n_states, self.n_states) ) - filtered_covariances.name = "filtered_covariances" + filtered_covariances.name = FILTER_OUTPUT_NAMES[2] predicted_covariances = pt.specify_shape( predicted_covariances, (n, self.n_states, self.n_states) ) - predicted_covariances.name = "predicted_covariances" + predicted_covariances.name = FILTER_OUTPUT_NAMES[3] + + observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) + observed_states.name = FILTER_OUTPUT_NAMES[4] observed_covariances = pt.specify_shape( observed_covariances, (n, self.n_endog, self.n_endog) ) - observed_covariances.name = "observed_covariances" + observed_covariances.name = FILTER_OUTPUT_NAMES[5] loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,)) loglike_obs.name = "loglike_obs" diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index 3e9a3958b..5df2ba4a5 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -38,14 +38,16 @@ LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES)) FILTER_OUTPUT_NAMES = [ - "filtered_state", - "predicted_state", - "filtered_covariance", - "predicted_covariance", + "filtered_states", + "predicted_states", + "filtered_covariances", + "predicted_covariances", + "predicted_observed_states", + "predicted_observed_covariances", ] -SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"] -OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"] +SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"] +OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"] MATRIX_DIMS = { "x0": (ALL_STATE_DIM,), @@ -60,14 +62,14 @@ } FILTER_OUTPUT_DIMS = { - "filtered_state": (TIME_DIM, ALL_STATE_DIM), - "smoothed_state": (TIME_DIM, ALL_STATE_DIM), - "predicted_state": (TIME_DIM, ALL_STATE_DIM), - "filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), - "predicted_observed_state": (TIME_DIM, OBS_STATE_DIM), - "predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM), + "filtered_states": (TIME_DIM, ALL_STATE_DIM), + "smoothed_states": (TIME_DIM, ALL_STATE_DIM), + "predicted_states": (TIME_DIM, ALL_STATE_DIM), + "filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM), + "predicted_observed_states": (TIME_DIM, OBS_STATE_DIM), + "predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM), } POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"] diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 63b076bd9..06aec484e 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -1,3 +1,5 @@ +import re + from collections.abc import Sequence from functools import partial @@ -485,16 +487,16 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill(): def test_build_statespace_graph(pymc_mod): for name in [ - "filtered_state", - "predicted_state", - "predicted_covariance", - "filtered_covariance", + "filtered_states", + "predicted_states", + "predicted_covariances", + "filtered_covariances", ]: assert name in [x.name for x in pymc_mod.deterministics] def test_build_smoother_graph(ss_mod, pymc_mod): - names = ["smoothed_state", "smoothed_covariance"] + names = ["smoothed_states", "smoothed_covariances"] for name in names: assert name in [x.name for x in pymc_mod.deterministics] @@ -1191,11 +1193,11 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_ # Check that the frozen states and covariances correctly match the sliced index np.testing.assert_allclose( - idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, + idata_exog.posterior["predicted_covariances"].sel(time=t0).mean(("chain", "draw")).values, idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, ) np.testing.assert_allclose( - idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, + idata_exog.posterior["predicted_states"].sel(time=t0).mean(("chain", "draw")).values, idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, ) @@ -1244,3 +1246,30 @@ def test_param_dims_coords(ss_mod_multi_component): assert i == len( ss_mod_multi_component.coords[s] ), f"Mismatch between shape {i} and dimension {s}" + + +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") +@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") +def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog): + # Simple tests + idata_filter_prior = exog_ss_mod.sample_filter_outputs( + idata_exog, filter_output_names=None, group="prior" + ) + + specific_outputs = ["filtered_states", "filtered_covariances"] + idata_filter_specific = exog_ss_mod.sample_filter_outputs( + idata_exog, filter_output_names=specific_outputs + ) + missing_outputs = np.setdiff1d( + specific_outputs, [x for x in idata_filter_specific.posterior_predictive.data_vars] + ) + + assert missing_outputs.size == 0 + + msg = "['filter_covariances' 'filter_states'] not a valid filter output name!" + incorrect_outputs = ["filter_states", "filter_covariances"] + with pytest.raises(ValueError, match=re.escape(msg)): + exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=incorrect_outputs) diff --git a/tests/statespace/models/test_SARIMAX.py b/tests/statespace/models/test_SARIMAX.py index 693b367d8..d04303b2f 100644 --- a/tests/statespace/models/test_SARIMAX.py +++ b/tests/statespace/models/test_SARIMAX.py @@ -321,7 +321,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng): @pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"]) def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): - rv = pymc_mod[f"{filter_output}_covariance"] + rv = pymc_mod[f"{filter_output}_covariances"] cov_mats = pm.draw(rv, 100, random_seed=rng) w, v = np.linalg.eig(cov_mats) assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}") diff --git a/tests/statespace/models/test_VARMAX.py b/tests/statespace/models/test_VARMAX.py index 31f112772..fbd0cfc04 100644 --- a/tests/statespace/models/test_VARMAX.py +++ b/tests/statespace/models/test_VARMAX.py @@ -156,7 +156,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng): @pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"]) def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): - rv = pymc_mod[f"{filter_output}_covariance"] + rv = pymc_mod[f"{filter_output}_covariances"] cov_mats = pm.draw(rv, 100, random_seed=rng) w, v = np.linalg.eig(cov_mats) assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}") diff --git a/tests/statespace/utils/test_coord_assignment.py b/tests/statespace/utils/test_coord_assignment.py index 22f98e921..4f27d4250 100644 --- a/tests/statespace/utils/test_coord_assignment.py +++ b/tests/statespace/utils/test_coord_assignment.py @@ -93,7 +93,7 @@ def test_filter_output_coord_assignment(f, warning, create_model): with warning: pymc_model = create_model(f) - for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]: + for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_states"]: assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output]