Skip to content

Commit 503eec5

Browse files
fix decompose test
1 parent ce14343 commit 503eec5

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tests/statespace/models/structural/test_core.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_extract_multiple_observed(rng):
126126
reg = st.RegressionComponent(
127127
state_names=["a", "b"], name="exog", observed_state_names=["data_2", "data_3"]
128128
)
129-
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_2", "data_3"])
129+
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_3"])
130130
mod = (ll + season + reg + me).build(verbose=True)
131131

132132
with pm.Model(coords=mod.coords) as m:
@@ -147,5 +147,18 @@ def test_extract_multiple_observed(rng):
147147
filter_prior = mod.sample_conditional_prior(prior)
148148
comp_prior = mod.extract_components_from_idata(filter_prior)
149149
comp_states = comp_prior.filtered_prior.coords["state"].values
150-
# expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151-
# missing = set(comp_states) - set(expected_states)
150+
151+
expected_states = [
152+
"trend[level[data_1]]",
153+
"trend[trend[data_1]]",
154+
"trend[level[data_2]]",
155+
"trend[trend[data_2]]",
156+
"seasonal",
157+
"exog[a[data_2]]",
158+
"exog[b[data_2]]",
159+
"exog[a[data_3]]",
160+
"exog[b[data_3]]",
161+
]
162+
163+
missing = set(comp_states) - set(expected_states)
164+
assert len(missing) == 0, missing

0 commit comments

Comments
 (0)