Skip to content

Commit de143ad

Browse files
Update tests to respect new naming convention
1 parent 486fa14 commit de143ad

File tree

6 files changed

+37
-39
lines changed

6 files changed

+37
-39
lines changed

tests/statespace/filters/test_distributions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def pymc_model(data):
5252
data = pm.Data("data", data.values)
5353
P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
5454
P0 = pm.Deterministic("P0", pt.diag(P0_diag))
55-
initial_trend = pm.Normal("level_trend_initial", shape=(2,))
56-
sigma_trend = pm.Exponential("level_trend_sigma", 1, shape=(2,))
55+
initial_trend = pm.Normal("initial_level_trend", shape=(2,))
56+
sigma_trend = pm.Exponential("sigma_level_trend", 1, shape=(2,))
5757

5858
return mod
5959

@@ -69,8 +69,8 @@ def pymc_model_2(data):
6969
with pm.Model(coords=coords) as mod:
7070
P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
7171
P0 = pm.Deterministic("P0", pt.diag(P0_diag))
72-
initial_trend = pm.Normal("level_trend_initial", shape=(2,))
73-
sigma_trend = pm.Exponential("level_trend_sigma", 1, shape=(2,))
72+
initial_trend = pm.Normal("initial_level_trend", shape=(2,))
73+
sigma_trend = pm.Exponential("sigma_level_trend", 1, shape=(2,))
7474
sigma_me = pm.Exponential("sigma_error", 1)
7575

7676
return mod
@@ -207,8 +207,8 @@ def test_lgss_with_time_varying_inputs(output_name, rng):
207207
exog_data = pm.Data("data_exog", X)
208208
P0_diag = pm.Exponential("P0_diag", 1, shape=(mod.k_states,))
209209
P0 = pm.Deterministic("P0", pt.diag(P0_diag))
210-
initial_trend = pm.Normal("level_trend_initial", shape=(2,))
211-
sigma_trend = pm.Exponential("level_trend_sigma", 1, shape=(2,))
210+
initial_trend = pm.Normal("initial_level_trend", shape=(2,))
211+
sigma_trend = pm.Exponential("sigma_level_trend", 1, shape=(2,))
212212
beta_exog = pm.Normal("beta_exog", shape=(3,))
213213

214214
mod._insert_random_variables()

tests/statespace/models/structural/components/test_level_trend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
def test_level_trend_model(rng):
1616
mod = st.LevelTrendComponent(order=2, innovations_order=0)
17-
params = {"level_trend_initial": [0.0, 1.0]}
17+
params = {"initial_level_trend": [0.0, 1.0]}
1818
x, y = simulate_from_numpy_model(mod, rng, params)
1919

2020
assert_allclose(np.diff(y), 1, atol=ATOL, rtol=RTOL)
2121

2222
# Check coords
2323
mod = mod.build(verbose=False)
2424
_assert_basic_coords_correct(mod)
25-
assert mod.coords["level_trend_state"] == ["level", "trend"]
25+
assert mod.coords["state_level_trend"] == ["level", "trend"]
2626

2727

2828
def test_level_trend_multiple_observed_construction():
@@ -34,8 +34,8 @@ def test_level_trend_multiple_observed_construction():
3434
assert mod.k_states == 6
3535
assert mod.k_posdef == 3
3636

37-
assert mod.coords["level_trend_state"] == ["level", "trend"]
38-
assert mod.coords["level_trend_endog"] == ["data_1", "data_2", "data_3"]
37+
assert mod.coords["state_level_trend"] == ["level", "trend"]
38+
assert mod.coords["endog_level_trend"] == ["data_1", "data_2", "data_3"]
3939

4040
assert mod.state_names == [
4141
"level[data_1]",
@@ -95,7 +95,7 @@ def test_level_trend_multiple_observed(rng):
9595
mod = st.LevelTrendComponent(
9696
order=2, innovations_order=0, observed_state_names=["data_1", "data_2", "data_3"]
9797
)
98-
params = {"level_trend_initial": np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])}
98+
params = {"initial_level_trend": np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])}
9999

100100
x, y = simulate_from_numpy_model(mod, rng, params)
101101
assert (np.diff(y, axis=0) == np.array([[1.0, 2.0, 3.0]])).all().all()
@@ -115,8 +115,8 @@ def test_add_level_trend_with_different_observed():
115115
assert mod.k_states == 3
116116
assert mod.k_posdef == 2
117117

118-
assert mod.coords["ll_state"] == ["level", "trend"]
119-
assert mod.coords["grw_state"] == ["level"]
118+
assert mod.coords["state_ll"] == ["level", "trend"]
119+
assert mod.coords["state_grw"] == ["level"]
120120

121121
assert mod.state_names == ["level[data_1]", "trend[data_1]", "level[data_2]"]
122122
assert mod.shock_names == ["trend[data_1]", "level[data_2]"]

tests/statespace/models/structural/components/test_regression.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_exogenous_component(rng):
2828

2929
mod = mod.build(verbose=False)
3030
_assert_basic_coords_correct(mod)
31-
assert mod.coords["exog_state"] == ["feature_1", "feature_2"]
31+
assert mod.coords["state_exog"] == ["feature_1", "feature_2"]
3232

3333

3434
def test_adding_exogenous_component(rng):
@@ -65,11 +65,9 @@ def test_regression_with_multiple_observed_states(rng):
6565
assert_allclose(y[:, 1], data @ params["beta_exog"][1], atol=ATOL, rtol=RTOL)
6666

6767
mod = mod.build(verbose=False)
68-
assert mod.coords["exog_state"] == [
69-
"feature_1[data_1]",
70-
"feature_2[data_1]",
71-
"feature_1[data_2]",
72-
"feature_2[data_2]",
68+
assert mod.coords["state_exog"] == [
69+
"feature_1",
70+
"feature_2",
7371
]
7472

7573
Z = mod.ssm["design"].eval({"data_exog": data})
@@ -90,8 +88,8 @@ def test_add_regression_components_with_multiple_observed_states(rng):
9088
reg2 = st.RegressionComponent(state_names=["c"], name="exog2", observed_state_names=["data_3"])
9189

9290
mod = (reg1 + reg2).build(verbose=False)
93-
assert mod.coords["exog1_state"] == ["a[data_1]", "b[data_1]", "a[data_2]", "b[data_2]"]
94-
assert mod.coords["exog2_state"] == ["c[data_3]"]
91+
assert mod.coords["state_exog1"] == ["a", "b"]
92+
assert mod.coords["state_exog2"] == ["c"]
9593

9694
Z = mod.ssm["design"].eval({"data_exog1": data_1, "data_exog2": data_2})
9795
vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)")
@@ -124,7 +122,7 @@ def test_filter_scans_time_varying_design_matrix(rng):
124122

125123
x0 = pm.Normal("x0", dims=["state"])
126124
P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
127-
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
125+
beta_exog = pm.Normal("beta_exog", dims=["state_exog"])
128126

129127
mod.build_statespace_graph(y)
130128
x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def create_structural_model_and_equivalent_statsmodel(
220220

221221
if level:
222222
level_trend_order[0] = 1
223-
expected_coords["level_state"] += [
223+
expected_coords["state_level"] += [
224224
"level",
225225
]
226226
expected_coords[ALL_STATE_DIM] += [
@@ -241,7 +241,7 @@ def create_structural_model_and_equivalent_statsmodel(
241241

242242
if trend:
243243
level_trend_order[1] = 1
244-
expected_coords["level_state"] += [
244+
expected_coords["state_level"] += [
245245
"trend",
246246
]
247247
expected_coords[ALL_STATE_DIM] += [
@@ -258,7 +258,7 @@ def create_structural_model_and_equivalent_statsmodel(
258258
expected_coords[SHOCK_AUX_DIM] += ["trend"]
259259

260260
if level or trend:
261-
expected_param_dims["level_initial"] += ("level_state",)
261+
expected_param_dims["initial_level"] += ("state_level",)
262262
level_value = np.where(
263263
level_trend_order,
264264
rng.normal(
@@ -272,13 +272,13 @@ def create_structural_model_and_equivalent_statsmodel(
272272
max_order = np.flatnonzero(level_value)[-1].item() + 1
273273
level_trend_order = level_trend_order[:max_order]
274274

275-
params["level_initial"] = level_value[:max_order]
275+
params["initial_level"] = level_value[:max_order]
276276
sm_init["level"] = level_value[0]
277277
sm_init["trend"] = level_value[1]
278278

279279
if sum(level_trend_innov_order) > 0:
280-
expected_param_dims["level_sigma"] += ("level_shock",)
281-
params["level_sigma"] = np.sqrt(sigma_level_value2)
280+
expected_param_dims["sigma_level"] += ("level_shock",)
281+
params["sigma_level"] = np.sqrt(sigma_level_value2)
282282

283283
sigma_level_value = sigma_level_value2.tolist()
284284
if stochastic_level:

tests/statespace/models/structural/test_core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_add_components():
2222
mod = ll + se
2323

2424
ll_params = {
25-
"level_trend_initial": np.zeros(2, dtype=floatX),
26-
"level_trend_sigma": np.ones(2, dtype=floatX),
25+
"initial_level_trend": np.zeros(2, dtype=floatX),
26+
"sigma_level_trend": np.ones(2, dtype=floatX),
2727
}
2828
se_params = {
2929
"seasonal_coefs": np.ones(11, dtype=floatX),
@@ -92,9 +92,9 @@ def test_extract_components_from_idata(rng):
9292

9393
x0 = pm.Normal("x0", dims=["state"])
9494
P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
95-
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
96-
initial_trend = pm.Normal("level_trend_initial", dims=["level_trend_state"])
97-
sigma_trend = pm.Exponential("level_trend_sigma", 1, dims=["level_trend_shock"])
95+
beta_exog = pm.Normal("beta_exog", dims=["state_exog"])
96+
initial_trend = pm.Normal("initial_level_trend", dims=["state_level_trend"])
97+
sigma_trend = pm.Exponential("sigma_level_trend", 1, dims=["level_trend_shock"])
9898
seasonal_coefs = pm.Normal("seasonal", dims=["seasonal_state"])
9999
sigma_obs = pm.Exponential("sigma_obs", 1)
100100

tests/statespace/utils/test_coord_assignment.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def _create_model(f):
8080
dims="state",
8181
)
8282
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=("state", "state_aux"))
83-
initial_trend = pm.Normal("level_trend_initial", dims="level_trend_state")
84-
sigma_trend = pm.Exponential("level_trend_sigma", 1, dims="level_trend_shock")
83+
initial_trend = pm.Normal("initial_level_trend", dims="state_level_trend")
84+
sigma_trend = pm.Exponential("sigma_level_trend", 1, dims="level_trend_shock")
8585
ss_mod.build_statespace_graph(data, save_kalman_filter_outputs_in_idata=True)
8686
return mod
8787

@@ -103,8 +103,8 @@ def test_model_build_without_coords(load_dataset):
103103
with pm.Model() as mod:
104104
P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
105105
P0 = pm.Deterministic("P0", pt.diag(P0_diag))
106-
initial_trend = pm.Normal("level_trend_initial", shape=(2,))
107-
sigma_trend = pm.Exponential("level_trend_sigma", 1, shape=(2,))
106+
initial_trend = pm.Normal("initial_level_trend", shape=(2,))
107+
sigma_trend = pm.Exponential("sigma_level_trend", 1, shape=(2,))
108108
ss_mod.build_statespace_graph(data, register_data=False)
109109

110110
assert mod.coords == {}
@@ -131,8 +131,8 @@ def make_model(index):
131131
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
132132
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
133133

134-
initial_trend = pm.Normal("level_trend_initial", dims=initial_trend_dims)
135-
sigma_trend = pm.Gamma("level_trend_sigma", alpha=2, beta=50, dims=sigma_trend_dims)
134+
initial_trend = pm.Normal("initial_level_trend", dims=initial_trend_dims)
135+
sigma_trend = pm.Gamma("sigma_level_trend", alpha=2, beta=50, dims=sigma_trend_dims)
136136

137137
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138138
ss_mod.build_statespace_graph(

0 commit comments

Comments
 (0)