Skip to content

Commit d142a91

Browse files
committed
updated plurality for some of the constants in constants.py
1 parent 5b064d4 commit d142a91

File tree

6 files changed

+38
-38
lines changed

6 files changed

+38
-38
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari
805805
states, covs = outputs[:4], outputs[4:]
806806

807807
state_names = [
808-
"filtered_state",
809-
"predicted_state",
810-
"predicted_observed_state",
811-
"smoothed_state",
808+
"filtered_states",
809+
"predicted_states",
810+
"predicted_observed_states",
811+
"smoothed_states",
812812
]
813813
cov_names = [
814-
"filtered_covariance",
815-
"predicted_covariance",
816-
"predicted_observed_covariance",
817-
"smoothed_covariance",
814+
"filtered_covariances",
815+
"predicted_covariances",
816+
"predicted_observed_covariances",
817+
"smoothed_covariances",
818818
]
819819

820820
with mod:
@@ -939,7 +939,7 @@ def build_statespace_graph(
939939
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
940940
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
941941

942-
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
942+
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"]
943943
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
944944

945945
SequenceMvNormal(
@@ -1727,14 +1727,14 @@ def sample_filter_outputs(
17271727
)
17281728

17291729
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
1730-
filter_output_dims_mapping = {}
1731-
for k in FILTER_OUTPUT_DIMS.keys():
1732-
filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
1730+
# filter_output_dims_mapping = {}
1731+
# for k in FILTER_OUTPUT_DIMS.keys():
1732+
# filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
17331733

17341734
all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
17351735
# This excludes observed states and observed covariances from the filter outputs
17361736
all_filter_outputs = [
1737-
output for output in all_filter_outputs if output.name in filter_output_dims_mapping
1737+
output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS
17381738
]
17391739

17401740
if filter_output_names is None:
@@ -1752,7 +1752,7 @@ def sample_filter_outputs(
17521752
]
17531753

17541754
for output in filter_output_names:
1755-
dims = filter_output_dims_mapping[output.name]
1755+
dims = FILTER_OUTPUT_DIMS[output.name]
17561756
pm.Deterministic(output.name, output, dims=dims)
17571757

17581758
frozen_model = freeze_dims_and_data(m)

pymc_extras/statespace/utils/constants.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))
3939

4040
FILTER_OUTPUT_NAMES = [
41-
"filtered_state",
42-
"predicted_state",
43-
"filtered_covariance",
44-
"predicted_covariance",
41+
"filtered_states",
42+
"predicted_states",
43+
"filtered_covariances",
44+
"predicted_covariances",
4545
]
4646

47-
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
48-
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
47+
SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]
48+
OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"]
4949

5050
MATRIX_DIMS = {
5151
"x0": (ALL_STATE_DIM,),
@@ -60,14 +60,14 @@
6060
}
6161

6262
FILTER_OUTPUT_DIMS = {
63-
"filtered_state": (TIME_DIM, ALL_STATE_DIM),
64-
"smoothed_state": (TIME_DIM, ALL_STATE_DIM),
65-
"predicted_state": (TIME_DIM, ALL_STATE_DIM),
66-
"filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
67-
"smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
68-
"predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
69-
"predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
70-
"predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
63+
"filtered_states": (TIME_DIM, ALL_STATE_DIM),
64+
"smoothed_states": (TIME_DIM, ALL_STATE_DIM),
65+
"predicted_states": (TIME_DIM, ALL_STATE_DIM),
66+
"filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
67+
"smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
68+
"predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
69+
"predicted_observed_states": (TIME_DIM, OBS_STATE_DIM),
70+
"predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
7171
}
7272

7373
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]

tests/statespace/core/test_statespace.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,16 +487,16 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill():
487487

488488
def test_build_statespace_graph(pymc_mod):
489489
for name in [
490-
"filtered_state",
491-
"predicted_state",
492-
"predicted_covariance",
493-
"filtered_covariance",
490+
"filtered_states",
491+
"predicted_states",
492+
"predicted_covariances",
493+
"filtered_covariances",
494494
]:
495495
assert name in [x.name for x in pymc_mod.deterministics]
496496

497497

498498
def test_build_smoother_graph(ss_mod, pymc_mod):
499-
names = ["smoothed_state", "smoothed_covariance"]
499+
names = ["smoothed_states", "smoothed_covariances"]
500500
for name in names:
501501
assert name in [x.name for x in pymc_mod.deterministics]
502502

@@ -1193,11 +1193,11 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_
11931193

11941194
# Check that the frozen states and covariances correctly match the sliced index
11951195
np.testing.assert_allclose(
1196-
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
1196+
idata_exog.posterior["predicted_covariances"].sel(time=t0).mean(("chain", "draw")).values,
11971197
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
11981198
)
11991199
np.testing.assert_allclose(
1200-
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
1200+
idata_exog.posterior["predicted_states"].sel(time=t0).mean(("chain", "draw")).values,
12011201
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
12021202
)
12031203

tests/statespace/models/test_SARIMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
321321

322322
@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
323323
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
324-
rv = pymc_mod[f"{filter_output}_covariance"]
324+
rv = pymc_mod[f"{filter_output}_covariances"]
325325
cov_mats = pm.draw(rv, 100, random_seed=rng)
326326
w, v = np.linalg.eig(cov_mats)
327327
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")

tests/statespace/models/test_VARMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng):
156156

157157
@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
158158
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
159-
rv = pymc_mod[f"{filter_output}_covariance"]
159+
rv = pymc_mod[f"{filter_output}_covariances"]
160160
cov_mats = pm.draw(rv, 100, random_seed=rng)
161161
w, v = np.linalg.eig(cov_mats)
162162
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")

tests/statespace/utils/test_coord_assignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_filter_output_coord_assignment(f, warning, create_model):
9393
with warning:
9494
pymc_model = create_model(f)
9595

96-
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]:
96+
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_states"]:
9797
assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output]
9898

9999

0 commit comments

Comments
 (0)