Skip to content

Commit 0b20dbc

Browse files
Add component name to shock state names
1 parent a6327b7 commit 0b20dbc

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def populate_component_properties(self):
100100
for i in range(k_states)
101101
]
102102

103-
self.shock_names = self.observed_state_names.copy()
103+
self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
104104
self.param_names = [f"{self.name}_params", f"{self.name}_sigma"]
105105
self.param_dims = {f"{self.name}_params": (f"{self.name}_lag",)}
106106
self.coords = {f"{self.name}_lag": self.ar_lags.tolist()}

tests/statespace/models/structural/components/test_autoregressive.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@
1313

1414
@pytest.mark.parametrize("order", [1, 2, [1, 0, 1]], ids=["AR1", "AR2", "AR(1,0,1)"])
1515
def test_autoregressive_model(order, rng):
16-
k = sum(order) if isinstance(order, list) else order
1716
ar = st.AutoregressiveComponent(order=order).build(verbose=False)
18-
params = {
19-
"auto_regressive_params": np.full((k,), 0.5, dtype=config.floatX),
20-
"auto_regressive_sigma": 0.1,
21-
"initial_state_cov": np.eye(k),
22-
}
2317

2418
# Check coords
2519
_assert_basic_coords_correct(ar)
@@ -47,7 +41,7 @@ def test_autoregressive_multiple_observed_build(rng):
4741
"L3[data_2]",
4842
]
4943

50-
assert mod.shock_names == ["data_1", "data_2"]
44+
assert mod.shock_names == ["auto_regressive[data_1]", "auto_regressive[data_2]"]
5145

5246
params = {
5347
"auto_regressive_params": np.full(
@@ -133,6 +127,6 @@ def test_add_autoregressive_different_observed():
133127
"L6[data_2]",
134128
]
135129

136-
assert mod.shock_names == ["data_1", "data_2"]
130+
assert mod.shock_names == ["ar1[data_1]", "ar6[data_2]"]
137131
assert mod.coords["ar1_lag"] == [1]
138132
assert mod.coords["ar6_lag"] == [1, 2, 3, 4, 5, 6]

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,8 @@ def create_structural_model_and_equivalent_statsmodel(
416416
expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
417417
expected_coords[ALL_STATE_DIM] += ar_names
418418
expected_coords[ALL_STATE_AUX_DIM] += ar_names
419-
expected_coords[SHOCK_DIM] += ["data"]
420-
expected_coords[SHOCK_AUX_DIM] += ["data"]
419+
expected_coords[SHOCK_DIM] += ["ar"]
420+
expected_coords[SHOCK_AUX_DIM] += ["ar"]
421421

422422
sm_params["sigma2.ar"] = sigma2
423423
for i, rho in enumerate(ar_params):

0 commit comments

Comments
 (0)