Skip to content

Commit b2ab6f8

Browse files
committed
Fix coefs_ to params_ renaming in broader tests
1 parent 49ff101 commit b2ab6f8

File tree

3 files changed

+19
-14
lines changed

3 files changed

+19
-14
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class StructuralTimeSeries(PyMCStateSpace):
120120
initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
121121
sigma_trend = pm.HalfNormal('sigma_trend', sigma=1, dims=ss_mod.param_dims['sigma_trend'])
122122
123-
seasonal_coefs = pm.Normal('seasonal_coefs', sigma=1, dims=ss_mod.param_dims['seasonal_coefs'])
123+
seasonal_coefs = pm.Normal('params_seasonal', sigma=1, dims=ss_mod.param_dims['params_seasonal'])
124124
sigma_seasonal = pm.HalfNormal('sigma_seasonal', sigma=1)
125125
126126
sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs'])

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _assert_coord_shapes_match_matrices(mod, params):
6767
n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
6868
n_obs = len(mod.coords[OBS_STATE_DIM])
6969

70+
print(f"{mod.coords[ALL_STATE_DIM] = }")
71+
print(f"{mod.coords[SHOCK_DIM] = }")
72+
print(f"{mod.coords[OBS_STATE_DIM] = }")
73+
print(f"{R = }")
74+
7075
assert x0.shape[-1:] == (
7176
n_states,
7277
), f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
@@ -104,12 +109,12 @@ def _assert_keys_match(test_dict, expected_dict):
104109
expected_keys = list(expected_dict.keys())
105110
param_keys = list(test_dict.keys())
106111
key_diff = set(expected_keys) - set(param_keys)
107-
assert len(key_diff) == 0, f'{", ".join(key_diff)} were not found in the test_dict keys.'
112+
assert len(key_diff) == 0, f"{', '.join(key_diff)} were not found in the test_dict keys."
108113

109114
key_diff = set(param_keys) - set(expected_keys)
110115
assert (
111116
len(key_diff) == 0
112-
), f'{", ".join(key_diff)} were keys of the tests_dict not in expected_dict.'
117+
), f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict."
113118

114119

115120
def _assert_param_dims_correct(param_dims, expected_dims):
@@ -296,8 +301,8 @@ def create_structural_model_and_equivalent_statsmodel(
296301
if seasonal is not None:
297302
state_names = [f"seasonal_{i}" for i in range(seasonal)][1:]
298303
seasonal_coefs = rng.normal(size=(seasonal - 1,)).astype(floatX)
299-
params["coefs_seasonal"] = seasonal_coefs
300-
expected_param_dims["coefs_seasonal"] += ("state_seasonal",)
304+
params["params_seasonal"] = seasonal_coefs
305+
expected_param_dims["params_seasonal"] += ("state_seasonal",)
301306

302307
expected_coords["state_seasonal"] += tuple(state_names)
303308
expected_coords[ALL_STATE_DIM] += state_names
@@ -335,8 +340,8 @@ def create_structural_model_and_equivalent_statsmodel(
335340

336341
seasonal_params = rng.normal(size=n_states).astype(floatX)
337342

338-
params[f"seasonal_{s}"] = seasonal_params
339-
expected_param_dims[f"seasonal_{s}"] += (f"state_seasonal_{s}",)
343+
params[f"params_seasonal_{s}"] = seasonal_params
344+
expected_param_dims[f"params_seasonal_{s}"] += (f"state_seasonal_{s}",)
340345
expected_coords[ALL_STATE_DIM] += state_names
341346
expected_coords[ALL_STATE_AUX_DIM] += state_names
342347
expected_coords[f"state_seasonal_{s}"] += (
@@ -404,7 +409,7 @@ def create_structural_model_and_equivalent_statsmodel(
404409
components.append(comp)
405410

406411
if autoregressive is not None:
407-
ar_names = [f"L{i+1}" for i in range(autoregressive)]
412+
ar_names = [f"L{i + 1}" for i in range(autoregressive)]
408413
params_ar = rng.normal(size=(autoregressive,)).astype(floatX)
409414
if autoregressive == 1:
410415
params_ar = params_ar.item()
@@ -421,8 +426,8 @@ def create_structural_model_and_equivalent_statsmodel(
421426

422427
sm_params["sigma2.ar"] = sigma2
423428
for i, rho in enumerate(params_ar):
424-
sm_init[f"ar.L{i+1}"] = 0
425-
sm_params[f"ar.L{i+1}"] = rho
429+
sm_init[f"ar.L{i + 1}"] = 0
430+
sm_params[f"ar.L{i + 1}"] = rho
426431

427432
comp = st.AutoregressiveComponent(name="ar", order=autoregressive)
428433
components.append(comp)
@@ -439,7 +444,7 @@ def create_structural_model_and_equivalent_statsmodel(
439444

440445
for i, beta in enumerate(betas):
441446
sm_params[f"beta.x{i + 1}"] = beta
442-
sm_init[f"beta.x{i+1}"] = beta
447+
sm_init[f"beta.x{i + 1}"] = beta
443448
comp = st.RegressionComponent(name="exog", state_names=names)
444449
components.append(comp)
445450

tests/statespace/models/structural/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_add_components():
2828
"sigma_level_trend": np.ones(2, dtype=floatX),
2929
}
3030
se_params = {
31-
"coefs_seasonal": np.ones(11, dtype=floatX),
31+
"params_seasonal": np.ones(11, dtype=floatX),
3232
"sigma_seasonal": 1.0,
3333
}
3434
all_params = ll_params.copy()
@@ -97,7 +97,7 @@ def test_extract_components_from_idata(rng):
9797
beta_exog = pm.Normal("beta_exog", dims=["state_exog"])
9898
initial_trend = pm.Normal("initial_level_trend", dims=["state_level_trend"])
9999
sigma_trend = pm.Exponential("sigma_level_trend", 1, dims=["shock_level_trend"])
100-
seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"])
100+
seasonal_coefs = pm.Normal("params_seasonal", dims=["state_seasonal"])
101101
sigma_obs = pm.Exponential("sigma_obs", 1)
102102

103103
mod.build_statespace_graph(y)
@@ -144,7 +144,7 @@ def test_extract_multiple_observed(rng):
144144
sigma_auto_regressive = pm.Normal("sigma_auto_regressive", dims=["endog_auto_regressive"])
145145
initial_trend = pm.Normal("initial_trend", dims=["endog_trend", "state_trend"])
146146
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"])
147-
seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"])
147+
seasonal_coefs = pm.Normal("params_seasonal", dims=["state_seasonal"])
148148
sigma_obs = pm.Exponential("sigma_obs", 1, dims=["endog_obs"])
149149

150150
mod.build_statespace_graph(y)

0 commit comments

Comments
 (0)