Skip to content

Commit b932255

Browse files
Update tests to new names
1 parent 3c5124d commit b932255

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

tests/statespace/core/test_statespace.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
167167
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
168168
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
169169

170-
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
170+
initial_trend = pm.Normal(
171+
"level_trend_initial", mu=[0], sigma=[0.005], dims=["level_trend_state"]
172+
)
171173

172174
data_exog = pm.Data(
173175
"data_exog", exog_data["x1"].values[:, None], dims=["time", "exog_state"]
@@ -184,12 +186,12 @@ def pymc_mod_no_exog(ss_mod_no_exog, rng):
184186
y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"])
185187

186188
with pm.Model(coords=ss_mod_no_exog.coords) as m:
187-
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
189+
initial_trend = pm.Normal("level_trend_initial", dims=["level_trend_state"])
188190
P0_sigma = pm.Exponential("P0_sigma", 1)
189191
P0 = pm.Deterministic(
190192
"P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
191193
)
192-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
194+
sigma_trend = pm.Exponential("level_trend_sigma", 1, dims=["level_trend_shock"])
193195
ss_mod_no_exog.build_statespace_graph(y)
194196

195197
return m
@@ -204,12 +206,12 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
204206
)
205207

206208
with pm.Model(coords=ss_mod_no_exog_dt.coords) as m:
207-
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
209+
initial_trend = pm.Normal("level_trend_initial", dims=["level_trend_state"])
208210
P0_sigma = pm.Exponential("P0_sigma", 1)
209211
P0 = pm.Deterministic(
210212
"P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
211213
)
212-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
214+
sigma_trend = pm.Exponential("level_trend_sigma", 1, dims=["level_trend_shock"])
213215
ss_mod_no_exog_dt.build_statespace_graph(y)
214216

215217
return m
@@ -313,7 +315,7 @@ def test_build_statespace_graph_warns_if_data_has_nans():
313315
ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
314316

315317
with pm.Model() as pymc_mod:
316-
initial_trend = pm.Normal("initial_trend", shape=(1,))
318+
initial_trend = pm.Normal("level_trend_initial", shape=(1,))
317319
P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX))
318320
with pytest.warns(pm.ImputationWarning):
319321
ss_mod.build_statespace_graph(
@@ -326,7 +328,7 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill():
326328
ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False)
327329

328330
with pm.Model() as pymc_mod:
329-
initial_trend = pm.Normal("initial_trend", shape=(1,))
331+
initial_trend = pm.Normal("level_trend_initial", shape=(1,))
330332
P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX))
331333
with pytest.raises(ValueError, match="Provided data contains the value 1.0"):
332334
data = np.ones((10, 1), dtype=floatX)

tests/statespace/utils/test_coord_assignment.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def _create_model(f):
8080
dims="state",
8181
)
8282
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=("state", "state_aux"))
83-
initial_trend = pm.Normal("initial_trend", dims="trend_state")
84-
sigma_trend = pm.Exponential("sigma_trend", 1, dims="trend_shock")
83+
initial_trend = pm.Normal("level_trend_initial", dims="level_trend_state")
84+
sigma_trend = pm.Exponential("level_trend_sigma", 1, dims="level_trend_shock")
8585
ss_mod.build_statespace_graph(data, save_kalman_filter_outputs_in_idata=True)
8686
return mod
8787

@@ -103,8 +103,8 @@ def test_model_build_without_coords(load_dataset):
103103
with pm.Model() as mod:
104104
P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
105105
P0 = pm.Deterministic("P0", pt.diag(P0_diag))
106-
initial_trend = pm.Normal("initial_trend", shape=(2,))
107-
sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,))
106+
initial_trend = pm.Normal("level_trend_initial", shape=(2,))
107+
sigma_trend = pm.Exponential("level_trend_sigma", 1, shape=(2,))
108108
ss_mod.build_statespace_graph(data, register_data=False)
109109

110110
assert mod.coords == {}
@@ -131,8 +131,8 @@ def make_model(index):
131131
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
132132
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
133133

134-
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
135-
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
134+
initial_trend = pm.Normal("level_trend_initial", dims=initial_trend_dims)
135+
sigma_trend = pm.Gamma("level_trend_sigma", alpha=2, beta=50, dims=sigma_trend_dims)
136136

137137
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138138
ss_mod.build_statespace_graph(

0 commit comments

Comments
 (0)