Skip to content

Commit d38c71b

Browse files
Broken test of decomposition with multiple observed
1 parent f8e7729 commit d38c71b

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

tests/statespace/models/structural/test_core.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def test_extract_components_from_idata(rng):
100100

101101
mod.build_statespace_graph(y)
102102

103-
x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
104103
prior = pm.sample_prior_predictive(draws=10)
105104

106105
filter_prior = mod.sample_conditional_prior(prior)
@@ -110,3 +109,43 @@ def test_extract_components_from_idata(rng):
110109
missing = set(comp_states) - set(expected_states)
111110

112111
assert len(missing) == 0, missing
112+
113+
114+
def test_extract_multiple_observed(rng):
115+
time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)
116+
data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx)
117+
118+
y = pd.DataFrame(
119+
rng.normal(size=(100, 3)), columns=["data_1", "data_2", "data_3"], index=time_idx
120+
)
121+
122+
ll = st.LevelTrendComponent(name="trend", observed_state_names=["data_1", "data_2"])
123+
season = st.FrequencySeasonality(
124+
name="seasonal", observed_state_names=["data_1"], season_length=12, n=2, innovations=False
125+
)
126+
reg = st.RegressionComponent(
127+
state_names=["a", "b"], name="exog", observed_state_names=["data_2", "data_3"]
128+
)
129+
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_2", "data_3"])
130+
mod = (ll + season + reg + me).build(verbose=True)
131+
132+
with pm.Model(coords=mod.coords) as m:
133+
data_exog = pm.Data("data_exog", data.values)
134+
135+
x0 = pm.Normal("x0", dims=["state"])
136+
P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
137+
beta_exog = pm.Normal("beta_exog", dims=["endog_exog", "state_exog"])
138+
initial_trend = pm.Normal("initial_trend", dims=["endog_trend", "state_trend"])
139+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "trend_shock"])
140+
seasonal_coefs = pm.Normal("seasonal", dims=["seasonal_state"])
141+
sigma_obs = pm.Exponential("sigma_obs", 1, dims=["endog_obs"])
142+
143+
mod.build_statespace_graph(y)
144+
145+
prior = pm.sample_prior_predictive(draws=10)
146+
147+
filter_prior = mod.sample_conditional_prior(prior)
148+
comp_prior = mod.extract_components_from_idata(filter_prior)
149+
comp_states = comp_prior.filtered_prior.coords["state"].values
150+
# expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
151+
# missing = set(comp_states) - set(expected_states)

0 commit comments

Comments
 (0)