Skip to content

Commit d944d62

Browse files
Update tests
1 parent aa3a011 commit d944d62

File tree

2 files changed

+35
-30
lines changed

2 files changed

+35
-30
lines changed

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def test_frequency_seasonality(n, s, rng):
326326
_assert_basic_coords_correct(mod)
327327
if n is None:
328328
n = int(s // 2)
329-
states = [f"{f}_season_{i}" for i in range(n) for f in ["Cos", "Sin"]]
329+
states = [f"{f}_{i}_season" for i in range(n) for f in ["Cos", "Sin"]]
330330

331331
# remove last state when model is completely saturated
332332
if s / n == 2.0:
@@ -345,19 +345,25 @@ def test_frequency_seasonality_multiple_observed(rng):
345345
observed_state_names=observed_state_names,
346346
)
347347
expected_state_names = [
348-
"Cos_season_0[data_1]",
349-
"Sin_season_0[data_1]",
350-
"Cos_season_1[data_1]",
351-
"Sin_season_1[data_1]",
352-
"Cos_season_0[data_2]",
353-
"Sin_season_0[data_2]",
354-
"Cos_season_1[data_2]",
355-
"Sin_season_1[data_2]",
348+
"Cos_0_season[data_1]",
349+
"Sin_0_season[data_1]",
350+
"Cos_1_season[data_1]",
351+
"Sin_1_season[data_1]",
352+
"Cos_0_season[data_2]",
353+
"Sin_0_season[data_2]",
354+
"Cos_1_season[data_2]",
355+
"Sin_1_season[data_2]",
356356
]
357357
assert mod.state_names == expected_state_names
358358
assert mod.shock_names == [
359-
"season[data_1]",
360-
"season[data_2]",
359+
"Cos_0_season[data_1]",
360+
"Sin_0_season[data_1]",
361+
"Cos_1_season[data_1]",
362+
"Sin_1_season[data_1]",
363+
"Cos_0_season[data_2]",
364+
"Sin_0_season[data_2]",
365+
"Cos_1_season[data_2]",
366+
"Sin_1_season[data_2]",
361367
]
362368

363369
x0 = np.zeros((2, 3), dtype=config.floatX)
@@ -372,9 +378,9 @@ def test_frequency_seasonality_multiple_observed(rng):
372378

373379
mod = mod.build(verbose=False)
374380
assert list(mod.coords["state_season"]) == [
375-
"Cos_season_0",
376-
"Sin_season_0",
377-
"Cos_season_1",
381+
"Cos_0_season",
382+
"Sin_0_season",
383+
"Cos_1_season",
378384
]
379385

380386
x0_sym, *_, T_sym, Z_sym, R_sym, _, Q_sym = mod._unpack_statespace_with_placeholders()
@@ -460,17 +466,21 @@ def test_add_two_frequency_seasonality_different_observed(rng):
460466
assert_pattern_repeats(y[:, 1], 6, atol=ATOL, rtol=RTOL)
461467

462468
assert mod.state_names == [
463-
"Cos_freq1_0[data_1]",
464-
"Sin_freq1_0[data_1]",
465-
"Cos_freq1_1[data_1]",
466-
"Sin_freq1_1[data_1]",
467-
"Cos_freq2_0[data_2]",
468-
"Sin_freq2_0[data_2]",
469+
"Cos_0_freq1[data_1]",
470+
"Sin_0_freq1[data_1]",
471+
"Cos_1_freq1[data_1]",
472+
"Sin_1_freq1[data_1]",
473+
"Cos_0_freq2[data_2]",
474+
"Sin_0_freq2[data_2]",
469475
]
470476

471477
assert mod.shock_names == [
472-
"freq1[data_1]",
473-
"freq2[data_2]",
478+
"Cos_0_freq1[data_1]",
479+
"Sin_0_freq1[data_1]",
480+
"Cos_1_freq1[data_1]",
481+
"Sin_1_freq1[data_1]",
482+
"Cos_0_freq2[data_2]",
483+
"Sin_0_freq2[data_2]",
474484
]
475485

476486
x0, *_, T = mod._unpack_statespace_with_placeholders()[:5]
@@ -627,7 +637,7 @@ def test_frequency_seasonality_coordinates(test_case):
627637

628638
# generate expected state names based on actual model name
629639
expected_state_names = [
630-
f"{f}_{model_name}_{i}" for i in range(test_case["n"]) for f in ["Cos", "Sin"]
640+
f"{f}_{i}_{model_name}" for i in range(test_case["n"]) for f in ["Cos", "Sin"]
631641
][: test_case["expected_shape"][-1]]
632642

633643
# assert coordinate structure
@@ -677,7 +687,7 @@ def test_frequency_seasonality_edge_cases(test_case):
677687

678688
# generate expected state names based on actual model name
679689
expected_state_names = [
680-
f"{f}_{model_name}_{i}" for i in range(test_case["n"]) for f in ["Cos", "Sin"]
690+
f"{f}_{i}_{model_name}" for i in range(test_case["n"]) for f in ["Cos", "Sin"]
681691
][: test_case["expected_shape"][-1]]
682692

683693
# assert coordinate structure

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,6 @@ def _assert_coord_shapes_match_matrices(mod, params):
6767
n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
6868
n_obs = len(mod.coords[OBS_STATE_DIM])
6969

70-
print(f"{mod.coords[ALL_STATE_DIM] = }")
71-
print(f"{mod.coords[SHOCK_DIM] = }")
72-
print(f"{mod.coords[OBS_STATE_DIM] = }")
73-
print(f"{R = }")
74-
7570
assert x0.shape[-1:] == (
7671
n_states,
7772
), f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
@@ -336,7 +331,7 @@ def create_structural_model_and_equivalent_statsmodel(
336331
s = d["period"]
337332
last_state_not_identified = (s / n) == 2.0
338333
n_states = 2 * n - int(last_state_not_identified)
339-
state_names = [f"{f}_seasonal_{s}_{i}" for i in range(n) for f in ["Cos", "Sin"]]
334+
state_names = [f"{f}_{i}_seasonal_{s}" for i in range(n) for f in ["Cos", "Sin"]]
340335

341336
seasonal_params = rng.normal(size=n_states).astype(floatX)
342337

0 commit comments

Comments
 (0)