Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 86 additions & 9 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari
states, covs = outputs[:4], outputs[4:]

state_names = [
"filtered_state",
"predicted_state",
"predicted_observed_state",
"smoothed_state",
"filtered_states",
"predicted_states",
"predicted_observed_states",
"smoothed_states",
]
cov_names = [
"filtered_covariance",
"predicted_covariance",
"predicted_observed_covariance",
"smoothed_covariance",
"filtered_covariances",
"predicted_covariances",
"predicted_observed_covariances",
"smoothed_covariances",
]

with mod:
Expand Down Expand Up @@ -939,7 +939,7 @@ def build_statespace_graph(
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)

obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"]
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None

SequenceMvNormal(
Expand Down Expand Up @@ -1678,6 +1678,83 @@ def sample_statespace_matrices(

return matrix_idata

def sample_filter_outputs(
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
):
if isinstance(filter_output_names, str):
filter_output_names = [filter_output_names]

drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shouldn't treat these as special (even though I agree it's silly to ask for them). I'd be confused if I tried to ask for them and it said it's not a valid filter output name.

Having everything in one place is convenient, even if it's duplicative.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, yeah I agree with you.

I just wasn't sure because in constants.py FILTER_OUTPUT_DIMS has predicted_observed_states and predicted_observed_covariances but the output from kalman_filter.build_graph() doesn't have predicted_observed_states and predicted_observed_covariances it seems like these are named observed_states and observed_covariances.

Should I change the names in constants.py to match the returned names from kalman_filter.build_graph()?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these should be consistent. But where does the name change currently happen between the filter and the idata? Maybe this is an issue for another PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski, I believe this happens in _postprocess_scan_results() in kalman_filter.py. It looks like the names of the filter outputs are hardcoded in there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to make this consistent in this PR I have no objection. I don't have a good sense if you should change FILTER_OUTPUT_DIMS to match the output names, or change the output names to match the FILTER_OUTPUT_DIMS. I'll defer to you if you have a sense of which one is better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I will do it in this PR because I think it is somewhat related. I think the names should match whatever we put in FILTER_OUTPUT_DIMS

all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys}

if filter_output_names is None:
filter_output_names = list(all_filter_output_dims.keys())
else:
unknown_filter_output_names = np.setdiff1d(
filter_output_names, list(all_filter_output_dims.keys())
)
if unknown_filter_output_names.size > 0:
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
filter_output_names = [
x for x in all_filter_output_dims.keys() if x in filter_output_names
]

compile_kwargs = kwargs.pop("compile_kwargs", {})
compile_kwargs.setdefault("mode", self.mode)

with pm.Model(coords=self.coords) as m:
self._build_dummy_graph()
self._insert_random_variables()

if self.data_names:
for name in self.data_names:
pm.Data(**self._exog_data_info[name])

self._insert_data_variables()

x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
data = self._fit_data

obs_coords = m.coords.get(OBS_STATE_DIM, None)

data, nan_mask = register_data_with_pymc(
data,
n_obs=self.ssm.k_endog,
obs_coords=obs_coords,
register_data=True,
)

filter_outputs = self.kalman_filter.build_graph(
data,
x0,
P0,
c,
d,
T,
Z,
R,
H,
Q,
)

smoother_outputs = self.kalman_smoother.build_graph(
T, R, Q, filter_outputs[0], filter_outputs[3]
)

filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
for output in filter_outputs:
if output.name in filter_output_names:
dims = all_filter_output_dims[output.name]
pm.Deterministic(output.name, output, dims=dims)

with freeze_dims_and_data(m):
return pm.sample_posterior_predictive(
idata if group == "posterior" else idata.prior,
var_names=filter_output_names,
compile_kwargs=compile_kwargs,
**kwargs,
)

@staticmethod
def _validate_forecast_args(
time_index: pd.RangeIndex | pd.DatetimeIndex,
Expand Down
28 changes: 14 additions & 14 deletions pymc_extras/statespace/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))

FILTER_OUTPUT_NAMES = [
"filtered_state",
"predicted_state",
"filtered_covariance",
"predicted_covariance",
"filtered_states",
"predicted_states",
"filtered_covariances",
"predicted_covariances",
]

SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]
OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"]

MATRIX_DIMS = {
"x0": (ALL_STATE_DIM,),
Expand All @@ -60,14 +60,14 @@
}

FILTER_OUTPUT_DIMS = {
"filtered_state": (TIME_DIM, ALL_STATE_DIM),
"smoothed_state": (TIME_DIM, ALL_STATE_DIM),
"predicted_state": (TIME_DIM, ALL_STATE_DIM),
"filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
"predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
"filtered_states": (TIME_DIM, ALL_STATE_DIM),
"smoothed_states": (TIME_DIM, ALL_STATE_DIM),
"predicted_states": (TIME_DIM, ALL_STATE_DIM),
"filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
"predicted_observed_states": (TIME_DIM, OBS_STATE_DIM),
"predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
}

POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]
Expand Down
43 changes: 36 additions & 7 deletions tests/statespace/core/test_statespace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

from collections.abc import Sequence
from functools import partial

Expand Down Expand Up @@ -485,16 +487,16 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill():

def test_build_statespace_graph(pymc_mod):
for name in [
"filtered_state",
"predicted_state",
"predicted_covariance",
"filtered_covariance",
"filtered_states",
"predicted_states",
"predicted_covariances",
"filtered_covariances",
]:
assert name in [x.name for x in pymc_mod.deterministics]


def test_build_smoother_graph(ss_mod, pymc_mod):
names = ["smoothed_state", "smoothed_covariance"]
names = ["smoothed_states", "smoothed_covariances"]
for name in names:
assert name in [x.name for x in pymc_mod.deterministics]

Expand Down Expand Up @@ -1191,11 +1193,11 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_

# Check that the frozen states and covariances correctly match the sliced index
np.testing.assert_allclose(
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
idata_exog.posterior["predicted_covariances"].sel(time=t0).mean(("chain", "draw")).values,
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
)
np.testing.assert_allclose(
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
idata_exog.posterior["predicted_states"].sel(time=t0).mean(("chain", "draw")).values,
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
)

Expand Down Expand Up @@ -1244,3 +1246,30 @@ def test_param_dims_coords(ss_mod_multi_component):
assert i == len(
ss_mod_multi_component.coords[s]
), f"Mismatch between shape {i} and dimension {s}"


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog):
# Simple tests
idata_filter_prior = exog_ss_mod.sample_filter_outputs(
idata_exog, filter_output_names=None, group="prior"
)

specific_outputs = ["filtered_states", "filtered_covariances"]
idata_filter_specific = exog_ss_mod.sample_filter_outputs(
idata_exog, filter_output_names=specific_outputs
)
missing_outputs = np.setdiff1d(
specific_outputs, [x for x in idata_filter_specific.posterior_predictive.data_vars]
)

assert missing_outputs.size == 0

msg = "['filter_covariances' 'filter_states'] not a valid filter output name!"
incorrect_outputs = ["filter_states", "filter_covariances"]
with pytest.raises(ValueError, match=re.escape(msg)):
exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=incorrect_outputs)
2 changes: 1 addition & 1 deletion tests/statespace/models/test_SARIMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):

@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
rv = pymc_mod[f"{filter_output}_covariance"]
rv = pymc_mod[f"{filter_output}_covariances"]
cov_mats = pm.draw(rv, 100, random_seed=rng)
w, v = np.linalg.eig(cov_mats)
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")
Expand Down
2 changes: 1 addition & 1 deletion tests/statespace/models/test_VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng):

@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
rv = pymc_mod[f"{filter_output}_covariance"]
rv = pymc_mod[f"{filter_output}_covariances"]
cov_mats = pm.draw(rv, 100, random_seed=rng)
w, v = np.linalg.eig(cov_mats)
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")
Expand Down
2 changes: 1 addition & 1 deletion tests/statespace/utils/test_coord_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_filter_output_coord_assignment(f, warning, create_model):
with warning:
pymc_model = create_model(f)

for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]:
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_states"]:
assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output]


Expand Down