1414RTOL = 0 if config .floatX .endswith ("64" ) else 1e-6
1515
1616
17- @pytest .fixture
18- def frequency_seasonality_params ():
19- """Common parameters for FrequencySeasonality tests."""
20- return {
21- "season_length" : 12 ,
22- "name" : "season" ,
23- "innovations" : True ,
24- }
25-
26-
27- @pytest .fixture
28- def time_seasonality_params ():
29- """Common parameters for TimeSeasonality tests."""
30- return {
31- "season_length" : 10 ,
32- "duration" : 1 ,
33- "innovations" : True ,
34- "name" : "season" ,
35- "remove_first_state" : True ,
36- }
37-
38-
39- def create_frequency_seasonality_model (** kwargs ):
40- """Helper function to create FrequencySeasonality models with common defaults."""
41- defaults = {
42- "season_length" : 12 ,
43- "n" : 2 ,
44- "name" : "season" ,
45- "innovations" : True ,
46- "observed_state_names" : ["data" ],
47- }
48- defaults .update (kwargs )
49- return FrequencySeasonality (** defaults )
50-
51-
52- def create_time_seasonality_model (** kwargs ):
53- """Helper function to create TimeSeasonality models with common defaults."""
54- defaults = {
55- "season_length" : 10 ,
56- "duration" : 1 ,
57- "innovations" : True ,
58- "name" : "season" ,
59- "remove_first_state" : True ,
60- "observed_state_names" : ["data" ],
61- }
62- defaults .update (kwargs )
63- return st .TimeSeasonality (** defaults )
64-
65-
66- def assert_coordinate_structure (model , expected_endog_names , expected_state_names ):
67- """Helper function to assert coordinate structure is correct."""
68- if len (expected_endog_names ) == 1 :
69- assert f"state_{ model .name } " in model .coords
70- assert model .coords [f"state_{ model .name } " ] == expected_state_names
71- else :
72- assert f"endog_{ model .name } " in model .coords
73- assert f"state_{ model .name } " in model .coords
74- assert model .coords [f"endog_{ model .name } " ] == expected_endog_names
75- assert model .coords [f"state_{ model .name } " ] == expected_state_names
76-
77-
78- def assert_parameter_structure (model , expected_shape , expected_dims = None ):
79- """Helper function to assert parameter structure is correct."""
80- param_name = f"params_{ model .name } "
81- assert param_name in model .param_info
82- assert model .param_info [param_name ]["shape" ] == expected_shape
83- if expected_dims :
84- assert model .param_info [param_name ]["dims" ] == expected_dims
85-
86-
8717@pytest .mark .parametrize ("s" , [10 , 25 , 50 ])
8818@pytest .mark .parametrize ("d" , [1 , 3 ])
8919@pytest .mark .parametrize ("innovations" , [True , False ])
@@ -313,6 +243,10 @@ def get_shift_factor(s):
313243@pytest .mark .parametrize ("s" , [5 , 10 , 25 , 25.2 ])
314244def test_frequency_seasonality (n , s , rng ):
315245 mod = st .FrequencySeasonality (season_length = s , n = n , name = "season" )
246+ assert mod .param_info ["sigma_season" ]["shape" ] == () # scalar for univariate
247+ assert mod .param_info ["sigma_season" ]["dims" ] is None
248+ assert len (mod .coords ["state_season" ]) == mod .n_coefs
249+
316250 x0 = rng .normal (size = mod .n_coefs ).astype (config .floatX )
317251 params = {"params_season" : x0 , "sigma_season" : 0.0 }
318252 k = get_shift_factor (s )
@@ -344,6 +278,10 @@ def test_frequency_seasonality_multiple_observed(rng):
344278 innovations = True ,
345279 observed_state_names = observed_state_names ,
346280 )
281+ assert mod .param_info ["params_season" ]["shape" ] == (mod .k_endog , mod .n_coefs )
282+ assert mod .param_info ["params_season" ]["dims" ] == ("endog_season" , "state_season" )
283+ assert mod .param_dims ["sigma_season" ] == ("endog_season" ,)
284+
347285 expected_state_names = [
348286 "Cos_0_season[data_1]" ,
349287 "Sin_0_season[data_1]" ,
@@ -521,71 +459,6 @@ def test_add_two_frequency_seasonality_different_observed(rng):
521459 np .testing .assert_allclose (expected_T , T_v , atol = ATOL , rtol = RTOL )
522460
523461
524- def test_time_seasonality_multivariate_parameter_shapes ():
525- """Test that TimeSeasonality correctly handles parameter shapes for multivariate data."""
526- mod_univariate = st .TimeSeasonality (
527- season_length = 4 ,
528- duration = 1 ,
529- innovations = True ,
530- name = "season" ,
531- observed_state_names = ["data" ],
532- )
533- mod_multivariate = st .TimeSeasonality (
534- season_length = 4 ,
535- duration = 1 ,
536- innovations = True ,
537- name = "season" ,
538- observed_state_names = ["data_1" , "data_2" ],
539- )
540-
541- assert mod_univariate .param_info ["sigma_season" ]["shape" ] == ()
542- assert mod_univariate .param_info ["sigma_season" ]["dims" ] is None
543-
544- assert mod_multivariate .param_info ["sigma_season" ]["shape" ] == (2 ,)
545- assert mod_multivariate .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
546- assert mod_multivariate .param_dims ["sigma_season" ] == ("endog_season" ,)
547-
548-
549- def test_frequency_seasonality_multivariate_parameter_shapes ():
550- """Test that FrequencySeasonality correctly handles parameter shapes for multivariate data."""
551- mod_univariate = st .FrequencySeasonality (
552- season_length = 4 ,
553- n = 2 ,
554- innovations = True ,
555- name = "season" ,
556- observed_state_names = ["data" ],
557- )
558- mod_multivariate = st .FrequencySeasonality (
559- season_length = 4 ,
560- n = 2 ,
561- innovations = True ,
562- name = "season" ,
563- observed_state_names = ["data_1" , "data_2" ],
564- )
565-
566- assert mod_univariate .param_info ["sigma_season" ]["shape" ] == () # scalar for univariate
567- assert mod_univariate .param_info ["sigma_season" ]["dims" ] is None
568-
569- assert mod_multivariate .param_info ["sigma_season" ]["shape" ] == (
570- 2 ,
571- ) # one value per endog variable
572- assert mod_multivariate .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
573-
574- # test with different n values
575- mod_multivariate_n1 = st .FrequencySeasonality (
576- season_length = 4 ,
577- n = 1 ,
578- innovations = True ,
579- name = "season" ,
580- observed_state_names = ["data_1" , "data_2" ],
581- )
582-
583- assert mod_multivariate_n1 .param_info ["sigma_season" ]["shape" ] == (
584- 2 ,
585- ) # one value per endog variable
586- assert mod_multivariate_n1 .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
587-
588-
589462@pytest .mark .parametrize (
590463 "test_case" ,
591464 [
@@ -617,42 +490,6 @@ def test_frequency_seasonality_multivariate_parameter_shapes():
617490 "observed_state_names" : ["data1" , "data2" ],
618491 "expected_shape" : (2 , 11 ),
619492 },
620- ],
621- )
622- def test_frequency_seasonality_coordinates (test_case ):
623- """Test that coordinate determination works correctly for different scenarios."""
624-
625- model_name = f"season_{ test_case ['name' ].split ('_' )[0 ]} "
626-
627- season = FrequencySeasonality (
628- season_length = test_case ["season_length" ],
629- n = test_case ["n" ],
630- name = model_name ,
631- observed_state_names = test_case ["observed_state_names" ],
632- )
633- season .populate_component_properties ()
634-
635- # assert parameter shape
636- assert season .param_info [f"params_{ model_name } " ]["shape" ] == test_case ["expected_shape" ]
637-
638- # generate expected state names based on actual model name
639- expected_state_names = [
640- f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
641- ][: test_case ["expected_shape" ][- 1 ]]
642-
643- # assert coordinate structure
644- if len (test_case ["observed_state_names" ]) == 1 :
645- assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][0 ]
646- assert season .coords [f"state_{ model_name } " ] == expected_state_names
647- else :
648- assert len (season .coords [f"endog_{ model_name } " ]) == test_case ["expected_shape" ][0 ]
649- assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][1 ]
650- assert season .coords [f"state_{ model_name } " ] == expected_state_names
651-
652-
653- @pytest .mark .parametrize (
654- "test_case" ,
655- [
656493 {
657494 "name" : "small_n" ,
658495 "season_length" : 12 ,
@@ -668,10 +505,9 @@ def test_frequency_seasonality_coordinates(test_case):
668505 "expected_shape" : (4 , 4 ),
669506 },
670507 ],
508+ ids = lambda x : x ["name" ],
671509)
672- def test_frequency_seasonality_edge_cases (test_case ):
673- """Test edge cases for coordinate determination."""
674-
510+ def test_frequency_seasonality_coordinates (test_case ):
675511 model_name = f"season_{ test_case ['name' ].split ('_' )[0 ]} "
676512
677513 season = FrequencySeasonality (
@@ -699,21 +535,11 @@ def test_frequency_seasonality_edge_cases(test_case):
699535 assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][1 ]
700536 assert season .coords [f"state_{ model_name } " ] == expected_state_names
701537
538+ # Check coords match the expected shape
539+ param_shape = season .param_info [f"params_{ model_name } " ]["shape" ]
540+ state_coords = season .coords [f"state_{ model_name } " ]
541+ endog_coords = season .coords .get (f"endog_{ model_name } " )
702542
703- def test_frequency_seasonality_parameter_consistency ():
704- """Test that parameter shapes and coordinates are consistent."""
705-
706- season = FrequencySeasonality (
707- season_length = 12 , n = 3 , name = "season" , observed_state_names = ["data1" , "data2" ]
708- )
709- season .populate_component_properties ()
710-
711- param_shape = season .param_info ["params_season" ]["shape" ]
712- state_coords = season .coords ["state_season" ]
713- endog_coords = season .coords ["endog_season" ]
714-
715- # for shape (k_endog, n_coefs), we should have:
716- # - len(endog_coords) == k_endog
717- # - len(state_coords) == n_coefs
718- assert len (endog_coords ) == param_shape [0 ]
719- assert len (state_coords ) == param_shape [1 ]
543+ assert len (state_coords ) == param_shape [- 1 ]
544+ if endog_coords :
545+ assert len (endog_coords ) == param_shape [0 ]
0 commit comments