Skip to content

Commit 0d4df37

Browse files
committed
added sample_filter_outputs utility and accompanying simple tests
1 parent c593ee0 commit 0d4df37

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,91 @@ def sample_statespace_matrices(
16781678

16791679
return matrix_idata
16801680

1681+
def sample_filter_outputs(
1682+
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
1683+
):
1684+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1685+
compile_kwargs.setdefault("mode", self.mode)
1686+
1687+
with pm.Model(coords=self.coords) as m:
1688+
pm_mod = modelcontext(None)
1689+
self._build_dummy_graph()
1690+
self._insert_random_variables()
1691+
1692+
if self.data_names:
1693+
for name in self.data_names:
1694+
pm.Data(**self._exog_data_info[name])
1695+
1696+
self._insert_data_variables()
1697+
1698+
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
1699+
data = self._fit_data
1700+
1701+
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
1702+
1703+
data, nan_mask = register_data_with_pymc(
1704+
data,
1705+
n_obs=self.ssm.k_endog,
1706+
obs_coords=obs_coords,
1707+
register_data=True,
1708+
)
1709+
1710+
filter_outputs = self.kalman_filter.build_graph(
1711+
data,
1712+
x0,
1713+
P0,
1714+
c,
1715+
d,
1716+
T,
1717+
Z,
1718+
R,
1719+
H,
1720+
Q,
1721+
)
1722+
1723+
smoother_outputs = self.kalman_smoother.build_graph(
1724+
T, R, Q, filter_outputs[0], filter_outputs[3]
1725+
)
1726+
1727+
all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1728+
1729+
if filter_output_names is None:
1730+
filter_output_names = all_filter_outputs
1731+
else:
1732+
unknown_filter_output_names = np.setdiff1d(
1733+
filter_output_names, [x.name for x in all_filter_outputs]
1734+
)
1735+
if unknown_filter_output_names.size > 0:
1736+
raise ValueError(
1737+
f"{unknown_filter_output_names} not a valid filter output name!"
1738+
)
1739+
filter_output_names = [
1740+
x for x in all_filter_outputs if x.name in filter_output_names
1741+
]
1742+
1743+
for output in filter_output_names:
1744+
match output.name:
1745+
case "filtered_states" | "predicted_states" | "smoothed_states":
1746+
dims = [TIME_DIM, "state"]
1747+
case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
1748+
dims = [TIME_DIM, "state", "state_aux"]
1749+
case "observed_states":
1750+
dims = [TIME_DIM, "observed_state"]
1751+
case "observed_covariances":
1752+
dims = [TIME_DIM, "observed_state", "observed_state_aux"]
1753+
1754+
pm.Deterministic(output.name, output, dims=dims)
1755+
1756+
frozen_model = freeze_dims_and_data(m)
1757+
with frozen_model:
1758+
idata_filter = pm.sample_posterior_predictive(
1759+
idata if group == "posterior" else idata.prior,
1760+
var_names=[x.name for x in frozen_model.deterministics],
1761+
compile_kwargs=compile_kwargs,
1762+
**kwargs,
1763+
)
1764+
return idata_filter
1765+
16811766
@staticmethod
16821767
def _validate_forecast_args(
16831768
time_index: pd.RangeIndex | pd.DatetimeIndex,

tests/statespace/core/test_statespace.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
from collections.abc import Sequence
24
from functools import partial
35

@@ -1017,3 +1019,30 @@ def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
10171019

10181020
assert forecasts.forecast_latent.shape[2] == n_periods
10191021
assert forecasts.forecast_observed.shape[2] == n_periods
1022+
1023+
1024+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
1025+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
1026+
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
1027+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
1028+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
1029+
def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog):
1030+
# Simple tests
1031+
idata_filter_prior = exog_ss_mod.sample_filter_outputs(
1032+
idata_exog, filter_output_names=None, group="prior"
1033+
)
1034+
1035+
specific_outputs = ["filtered_states", "filtered_covariances"]
1036+
idata_filter_specific = exog_ss_mod.sample_filter_outputs(
1037+
idata_exog, filter_output_names=specific_outputs
1038+
)
1039+
missing_outputs = np.setdiff1d(
1040+
specific_outputs, [x for x in idata_filter_specific.posterior_predictive.data_vars]
1041+
)
1042+
1043+
assert missing_outputs.size == 0
1044+
1045+
msg = "['filter_covariances' 'filter_states'] not a valid filter output name!"
1046+
incorrect_outputs = ["filter_states", "filter_covariances"]
1047+
with pytest.raises(ValueError, match=re.escape(msg)):
1048+
exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=incorrect_outputs)

0 commit comments

Comments
 (0)