Skip to content

Commit c5ace74

Browse files
Simplify tests
1 parent d944d62 commit c5ace74

File tree

1 file changed

+17
-191
lines changed

1 file changed

+17
-191
lines changed

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 17 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -14,76 +14,6 @@
1414
RTOL = 0 if config.floatX.endswith("64") else 1e-6
1515

1616

17-
@pytest.fixture
18-
def frequency_seasonality_params():
19-
"""Common parameters for FrequencySeasonality tests."""
20-
return {
21-
"season_length": 12,
22-
"name": "season",
23-
"innovations": True,
24-
}
25-
26-
27-
@pytest.fixture
28-
def time_seasonality_params():
29-
"""Common parameters for TimeSeasonality tests."""
30-
return {
31-
"season_length": 10,
32-
"duration": 1,
33-
"innovations": True,
34-
"name": "season",
35-
"remove_first_state": True,
36-
}
37-
38-
39-
def create_frequency_seasonality_model(**kwargs):
40-
"""Helper function to create FrequencySeasonality models with common defaults."""
41-
defaults = {
42-
"season_length": 12,
43-
"n": 2,
44-
"name": "season",
45-
"innovations": True,
46-
"observed_state_names": ["data"],
47-
}
48-
defaults.update(kwargs)
49-
return FrequencySeasonality(**defaults)
50-
51-
52-
def create_time_seasonality_model(**kwargs):
53-
"""Helper function to create TimeSeasonality models with common defaults."""
54-
defaults = {
55-
"season_length": 10,
56-
"duration": 1,
57-
"innovations": True,
58-
"name": "season",
59-
"remove_first_state": True,
60-
"observed_state_names": ["data"],
61-
}
62-
defaults.update(kwargs)
63-
return st.TimeSeasonality(**defaults)
64-
65-
66-
def assert_coordinate_structure(model, expected_endog_names, expected_state_names):
67-
"""Helper function to assert coordinate structure is correct."""
68-
if len(expected_endog_names) == 1:
69-
assert f"state_{model.name}" in model.coords
70-
assert model.coords[f"state_{model.name}"] == expected_state_names
71-
else:
72-
assert f"endog_{model.name}" in model.coords
73-
assert f"state_{model.name}" in model.coords
74-
assert model.coords[f"endog_{model.name}"] == expected_endog_names
75-
assert model.coords[f"state_{model.name}"] == expected_state_names
76-
77-
78-
def assert_parameter_structure(model, expected_shape, expected_dims=None):
79-
"""Helper function to assert parameter structure is correct."""
80-
param_name = f"params_{model.name}"
81-
assert param_name in model.param_info
82-
assert model.param_info[param_name]["shape"] == expected_shape
83-
if expected_dims:
84-
assert model.param_info[param_name]["dims"] == expected_dims
85-
86-
8717
@pytest.mark.parametrize("s", [10, 25, 50])
8818
@pytest.mark.parametrize("d", [1, 3])
8919
@pytest.mark.parametrize("innovations", [True, False])
@@ -313,6 +243,10 @@ def get_shift_factor(s):
313243
@pytest.mark.parametrize("s", [5, 10, 25, 25.2])
314244
def test_frequency_seasonality(n, s, rng):
315245
mod = st.FrequencySeasonality(season_length=s, n=n, name="season")
246+
assert mod.param_info["sigma_season"]["shape"] == () # scalar for univariate
247+
assert mod.param_info["sigma_season"]["dims"] is None
248+
assert len(mod.coords["state_season"]) == mod.n_coefs
249+
316250
x0 = rng.normal(size=mod.n_coefs).astype(config.floatX)
317251
params = {"params_season": x0, "sigma_season": 0.0}
318252
k = get_shift_factor(s)
@@ -344,6 +278,10 @@ def test_frequency_seasonality_multiple_observed(rng):
344278
innovations=True,
345279
observed_state_names=observed_state_names,
346280
)
281+
assert mod.param_info["params_season"]["shape"] == (mod.k_endog, mod.n_coefs)
282+
assert mod.param_info["params_season"]["dims"] == ("endog_season", "state_season")
283+
assert mod.param_dims["sigma_season"] == ("endog_season",)
284+
347285
expected_state_names = [
348286
"Cos_0_season[data_1]",
349287
"Sin_0_season[data_1]",
@@ -521,71 +459,6 @@ def test_add_two_frequency_seasonality_different_observed(rng):
521459
np.testing.assert_allclose(expected_T, T_v, atol=ATOL, rtol=RTOL)
522460

523461

524-
def test_time_seasonality_multivariate_parameter_shapes():
525-
"""Test that TimeSeasonality correctly handles parameter shapes for multivariate data."""
526-
mod_univariate = st.TimeSeasonality(
527-
season_length=4,
528-
duration=1,
529-
innovations=True,
530-
name="season",
531-
observed_state_names=["data"],
532-
)
533-
mod_multivariate = st.TimeSeasonality(
534-
season_length=4,
535-
duration=1,
536-
innovations=True,
537-
name="season",
538-
observed_state_names=["data_1", "data_2"],
539-
)
540-
541-
assert mod_univariate.param_info["sigma_season"]["shape"] == ()
542-
assert mod_univariate.param_info["sigma_season"]["dims"] is None
543-
544-
assert mod_multivariate.param_info["sigma_season"]["shape"] == (2,)
545-
assert mod_multivariate.param_info["sigma_season"]["dims"] == ("endog_season",)
546-
assert mod_multivariate.param_dims["sigma_season"] == ("endog_season",)
547-
548-
549-
def test_frequency_seasonality_multivariate_parameter_shapes():
550-
"""Test that FrequencySeasonality correctly handles parameter shapes for multivariate data."""
551-
mod_univariate = st.FrequencySeasonality(
552-
season_length=4,
553-
n=2,
554-
innovations=True,
555-
name="season",
556-
observed_state_names=["data"],
557-
)
558-
mod_multivariate = st.FrequencySeasonality(
559-
season_length=4,
560-
n=2,
561-
innovations=True,
562-
name="season",
563-
observed_state_names=["data_1", "data_2"],
564-
)
565-
566-
assert mod_univariate.param_info["sigma_season"]["shape"] == () # scalar for univariate
567-
assert mod_univariate.param_info["sigma_season"]["dims"] is None
568-
569-
assert mod_multivariate.param_info["sigma_season"]["shape"] == (
570-
2,
571-
) # one value per endog variable
572-
assert mod_multivariate.param_info["sigma_season"]["dims"] == ("endog_season",)
573-
574-
# test with different n values
575-
mod_multivariate_n1 = st.FrequencySeasonality(
576-
season_length=4,
577-
n=1,
578-
innovations=True,
579-
name="season",
580-
observed_state_names=["data_1", "data_2"],
581-
)
582-
583-
assert mod_multivariate_n1.param_info["sigma_season"]["shape"] == (
584-
2,
585-
) # one value per endog variable
586-
assert mod_multivariate_n1.param_info["sigma_season"]["dims"] == ("endog_season",)
587-
588-
589462
@pytest.mark.parametrize(
590463
"test_case",
591464
[
@@ -617,42 +490,6 @@ def test_frequency_seasonality_multivariate_parameter_shapes():
617490
"observed_state_names": ["data1", "data2"],
618491
"expected_shape": (2, 11),
619492
},
620-
],
621-
)
622-
def test_frequency_seasonality_coordinates(test_case):
623-
"""Test that coordinate determination works correctly for different scenarios."""
624-
625-
model_name = f"season_{test_case['name'].split('_')[0]}"
626-
627-
season = FrequencySeasonality(
628-
season_length=test_case["season_length"],
629-
n=test_case["n"],
630-
name=model_name,
631-
observed_state_names=test_case["observed_state_names"],
632-
)
633-
season.populate_component_properties()
634-
635-
# assert parameter shape
636-
assert season.param_info[f"params_{model_name}"]["shape"] == test_case["expected_shape"]
637-
638-
# generate expected state names based on actual model name
639-
expected_state_names = [
640-
f"{f}_{i}_{model_name}" for i in range(test_case["n"]) for f in ["Cos", "Sin"]
641-
][: test_case["expected_shape"][-1]]
642-
643-
# assert coordinate structure
644-
if len(test_case["observed_state_names"]) == 1:
645-
assert len(season.coords[f"state_{model_name}"]) == test_case["expected_shape"][0]
646-
assert season.coords[f"state_{model_name}"] == expected_state_names
647-
else:
648-
assert len(season.coords[f"endog_{model_name}"]) == test_case["expected_shape"][0]
649-
assert len(season.coords[f"state_{model_name}"]) == test_case["expected_shape"][1]
650-
assert season.coords[f"state_{model_name}"] == expected_state_names
651-
652-
653-
@pytest.mark.parametrize(
654-
"test_case",
655-
[
656493
{
657494
"name": "small_n",
658495
"season_length": 12,
@@ -668,10 +505,9 @@ def test_frequency_seasonality_coordinates(test_case):
668505
"expected_shape": (4, 4),
669506
},
670507
],
508+
ids=lambda x: x["name"],
671509
)
672-
def test_frequency_seasonality_edge_cases(test_case):
673-
"""Test edge cases for coordinate determination."""
674-
510+
def test_frequency_seasonality_coordinates(test_case):
675511
model_name = f"season_{test_case['name'].split('_')[0]}"
676512

677513
season = FrequencySeasonality(
@@ -699,21 +535,11 @@ def test_frequency_seasonality_edge_cases(test_case):
699535
assert len(season.coords[f"state_{model_name}"]) == test_case["expected_shape"][1]
700536
assert season.coords[f"state_{model_name}"] == expected_state_names
701537

538+
# Check coords match the expected shape
539+
param_shape = season.param_info[f"params_{model_name}"]["shape"]
540+
state_coords = season.coords[f"state_{model_name}"]
541+
endog_coords = season.coords.get(f"endog_{model_name}")
702542

703-
def test_frequency_seasonality_parameter_consistency():
704-
"""Test that parameter shapes and coordinates are consistent."""
705-
706-
season = FrequencySeasonality(
707-
season_length=12, n=3, name="season", observed_state_names=["data1", "data2"]
708-
)
709-
season.populate_component_properties()
710-
711-
param_shape = season.param_info["params_season"]["shape"]
712-
state_coords = season.coords["state_season"]
713-
endog_coords = season.coords["endog_season"]
714-
715-
# for shape (k_endog, n_coefs), we should have:
716-
# - len(endog_coords) == k_endog
717-
# - len(state_coords) == n_coefs
718-
assert len(endog_coords) == param_shape[0]
719-
assert len(state_coords) == param_shape[1]
543+
assert len(state_coords) == param_shape[-1]
544+
if endog_coords:
545+
assert len(endog_coords) == param_shape[0]

0 commit comments

Comments
 (0)