Skip to content

Commit bba8431

Browse files
Allow multiple observed states in measurement error component
1 parent 7cae487 commit bba8431

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

pymc_extras/statespace/models/structural/components/measurement_error.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,22 @@ def __init__(
6464
def populate_component_properties(self):
6565
self.param_names = [f"sigma_{self.name}"]
6666
self.param_dims = {}
67+
self.coords = {}
68+
69+
if self.k_endog > 1:
70+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
71+
self.coords[f"endog_{self.name}"] = self.observed_state_names
72+
6773
self.param_info = {
6874
f"sigma_{self.name}": {
69-
"shape": (),
75+
"shape": (self.k_endog,) if self.k_endog > 1 else (),
7076
"constraints": "Positive",
71-
"dims": None,
77+
"dims": (f"endog_{self.name}",) if self.k_endog > 1 else None,
7278
}
7379
}
7480

7581
def make_symbolic_graph(self) -> None:
76-
sigma_shape = ()
82+
sigma_shape = () if self.k_endog == 1 else (self.k_endog,)
7783
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)
7884
diag_idx = np.diag_indices(self.k_endog)
7985
idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]

tests/statespace/models/structural/components/test_measurement_error.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from pymc_extras.statespace.models import structural as st
24
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
35

@@ -8,3 +10,23 @@ def test_measurement_error(rng):
810

911
_assert_basic_coords_correct(mod)
1012
assert "sigma_obs" in mod.param_names
13+
14+
15+
def test_measurement_error_multiple_observed():
16+
mod = st.MeasurementError("obs", observed_state_names=["data_1", "data_2"])
17+
assert mod.k_endog == 2
18+
assert mod.coords["endog_obs"] == ["data_1", "data_2"]
19+
assert mod.param_dims["sigma_obs"] == ("endog_obs",)
20+
21+
22+
def test_build_with_measurement_error_subset():
23+
ll = st.LevelTrendComponent(order=2, observed_state_names=["data_1", "data_2", "data_3"])
24+
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_3"])
25+
mod = (ll + me).build()
26+
27+
H = mod.ssm["obs_cov"]
28+
assert H.type.shape == (3, 3)
29+
np.testing.assert_allclose(
30+
H.eval({"sigma_obs": [1.0, 3.0]}),
31+
np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 9.0]]),
32+
)

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,8 @@ def create_structural_model_and_equivalent_statsmodel(
416416
expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
417417
expected_coords[ALL_STATE_DIM] += ar_names
418418
expected_coords[ALL_STATE_AUX_DIM] += ar_names
419-
expected_coords[SHOCK_DIM] += ["ar_innovation"]
420-
expected_coords[SHOCK_AUX_DIM] += ["ar_innovation"]
419+
expected_coords[SHOCK_DIM] += ["data_ar_innovation"]
420+
expected_coords[SHOCK_AUX_DIM] += ["data_ar_innovation"]
421421

422422
sm_params["sigma2.ar"] = sigma2
423423
for i, rho in enumerate(ar_params):

tests/statespace/models/structural/test_core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ def test_add_components():
6464
assert_allclose(all_mat, np.concatenate([ll_mat, se_mat], axis=axis), atol=ATOL, rtol=RTOL)
6565

6666

67+
def test_add_components_multiple_observed():
68+
ll = st.LevelTrendComponent(order=2, observed_state_names=["data_1", "data_2"])
69+
me = st.MeasurementError(name="obs", observed_state_names=["data_1", "data_2"])
70+
71+
mod = (ll + me).build()
72+
73+
for property in ["param_names", "shock_names", "param_info", "coords", "param_dims"]:
74+
assert [x in getattr(mod, property) for x in getattr(ll, property)]
75+
76+
6777
@pytest.mark.skipif(floatX.endswith("32"), reason="Prior covariance not PSD at half-precision")
6878
def test_extract_components_from_idata(rng):
6979
time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)

0 commit comments

Comments
 (0)