14
14
RTOL = 0 if config .floatX .endswith ("64" ) else 1e-6
15
15
16
16
17
- @pytest .fixture
18
- def frequency_seasonality_params ():
19
- """Common parameters for FrequencySeasonality tests."""
20
- return {
21
- "season_length" : 12 ,
22
- "name" : "season" ,
23
- "innovations" : True ,
24
- }
25
-
26
-
27
- @pytest .fixture
28
- def time_seasonality_params ():
29
- """Common parameters for TimeSeasonality tests."""
30
- return {
31
- "season_length" : 10 ,
32
- "duration" : 1 ,
33
- "innovations" : True ,
34
- "name" : "season" ,
35
- "remove_first_state" : True ,
36
- }
37
-
38
-
39
- def create_frequency_seasonality_model (** kwargs ):
40
- """Helper function to create FrequencySeasonality models with common defaults."""
41
- defaults = {
42
- "season_length" : 12 ,
43
- "n" : 2 ,
44
- "name" : "season" ,
45
- "innovations" : True ,
46
- "observed_state_names" : ["data" ],
47
- }
48
- defaults .update (kwargs )
49
- return FrequencySeasonality (** defaults )
50
-
51
-
52
- def create_time_seasonality_model (** kwargs ):
53
- """Helper function to create TimeSeasonality models with common defaults."""
54
- defaults = {
55
- "season_length" : 10 ,
56
- "duration" : 1 ,
57
- "innovations" : True ,
58
- "name" : "season" ,
59
- "remove_first_state" : True ,
60
- "observed_state_names" : ["data" ],
61
- }
62
- defaults .update (kwargs )
63
- return st .TimeSeasonality (** defaults )
64
-
65
-
66
- def assert_coordinate_structure (model , expected_endog_names , expected_state_names ):
67
- """Helper function to assert coordinate structure is correct."""
68
- if len (expected_endog_names ) == 1 :
69
- assert f"state_{ model .name } " in model .coords
70
- assert model .coords [f"state_{ model .name } " ] == expected_state_names
71
- else :
72
- assert f"endog_{ model .name } " in model .coords
73
- assert f"state_{ model .name } " in model .coords
74
- assert model .coords [f"endog_{ model .name } " ] == expected_endog_names
75
- assert model .coords [f"state_{ model .name } " ] == expected_state_names
76
-
77
-
78
- def assert_parameter_structure (model , expected_shape , expected_dims = None ):
79
- """Helper function to assert parameter structure is correct."""
80
- param_name = f"params_{ model .name } "
81
- assert param_name in model .param_info
82
- assert model .param_info [param_name ]["shape" ] == expected_shape
83
- if expected_dims :
84
- assert model .param_info [param_name ]["dims" ] == expected_dims
85
-
86
-
87
17
@pytest .mark .parametrize ("s" , [10 , 25 , 50 ])
88
18
@pytest .mark .parametrize ("d" , [1 , 3 ])
89
19
@pytest .mark .parametrize ("innovations" , [True , False ])
@@ -313,6 +243,10 @@ def get_shift_factor(s):
313
243
@pytest .mark .parametrize ("s" , [5 , 10 , 25 , 25.2 ])
314
244
def test_frequency_seasonality (n , s , rng ):
315
245
mod = st .FrequencySeasonality (season_length = s , n = n , name = "season" )
246
+ assert mod .param_info ["sigma_season" ]["shape" ] == () # scalar for univariate
247
+ assert mod .param_info ["sigma_season" ]["dims" ] is None
248
+ assert len (mod .coords ["state_season" ]) == mod .n_coefs
249
+
316
250
x0 = rng .normal (size = mod .n_coefs ).astype (config .floatX )
317
251
params = {"params_season" : x0 , "sigma_season" : 0.0 }
318
252
k = get_shift_factor (s )
@@ -344,6 +278,10 @@ def test_frequency_seasonality_multiple_observed(rng):
344
278
innovations = True ,
345
279
observed_state_names = observed_state_names ,
346
280
)
281
+ assert mod .param_info ["params_season" ]["shape" ] == (mod .k_endog , mod .n_coefs )
282
+ assert mod .param_info ["params_season" ]["dims" ] == ("endog_season" , "state_season" )
283
+ assert mod .param_dims ["sigma_season" ] == ("endog_season" ,)
284
+
347
285
expected_state_names = [
348
286
"Cos_0_season[data_1]" ,
349
287
"Sin_0_season[data_1]" ,
@@ -521,71 +459,6 @@ def test_add_two_frequency_seasonality_different_observed(rng):
521
459
np .testing .assert_allclose (expected_T , T_v , atol = ATOL , rtol = RTOL )
522
460
523
461
524
- def test_time_seasonality_multivariate_parameter_shapes ():
525
- """Test that TimeSeasonality correctly handles parameter shapes for multivariate data."""
526
- mod_univariate = st .TimeSeasonality (
527
- season_length = 4 ,
528
- duration = 1 ,
529
- innovations = True ,
530
- name = "season" ,
531
- observed_state_names = ["data" ],
532
- )
533
- mod_multivariate = st .TimeSeasonality (
534
- season_length = 4 ,
535
- duration = 1 ,
536
- innovations = True ,
537
- name = "season" ,
538
- observed_state_names = ["data_1" , "data_2" ],
539
- )
540
-
541
- assert mod_univariate .param_info ["sigma_season" ]["shape" ] == ()
542
- assert mod_univariate .param_info ["sigma_season" ]["dims" ] is None
543
-
544
- assert mod_multivariate .param_info ["sigma_season" ]["shape" ] == (2 ,)
545
- assert mod_multivariate .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
546
- assert mod_multivariate .param_dims ["sigma_season" ] == ("endog_season" ,)
547
-
548
-
549
- def test_frequency_seasonality_multivariate_parameter_shapes ():
550
- """Test that FrequencySeasonality correctly handles parameter shapes for multivariate data."""
551
- mod_univariate = st .FrequencySeasonality (
552
- season_length = 4 ,
553
- n = 2 ,
554
- innovations = True ,
555
- name = "season" ,
556
- observed_state_names = ["data" ],
557
- )
558
- mod_multivariate = st .FrequencySeasonality (
559
- season_length = 4 ,
560
- n = 2 ,
561
- innovations = True ,
562
- name = "season" ,
563
- observed_state_names = ["data_1" , "data_2" ],
564
- )
565
-
566
- assert mod_univariate .param_info ["sigma_season" ]["shape" ] == () # scalar for univariate
567
- assert mod_univariate .param_info ["sigma_season" ]["dims" ] is None
568
-
569
- assert mod_multivariate .param_info ["sigma_season" ]["shape" ] == (
570
- 2 ,
571
- ) # one value per endog variable
572
- assert mod_multivariate .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
573
-
574
- # test with different n values
575
- mod_multivariate_n1 = st .FrequencySeasonality (
576
- season_length = 4 ,
577
- n = 1 ,
578
- innovations = True ,
579
- name = "season" ,
580
- observed_state_names = ["data_1" , "data_2" ],
581
- )
582
-
583
- assert mod_multivariate_n1 .param_info ["sigma_season" ]["shape" ] == (
584
- 2 ,
585
- ) # one value per endog variable
586
- assert mod_multivariate_n1 .param_info ["sigma_season" ]["dims" ] == ("endog_season" ,)
587
-
588
-
589
462
@pytest .mark .parametrize (
590
463
"test_case" ,
591
464
[
@@ -617,42 +490,6 @@ def test_frequency_seasonality_multivariate_parameter_shapes():
617
490
"observed_state_names" : ["data1" , "data2" ],
618
491
"expected_shape" : (2 , 11 ),
619
492
},
620
- ],
621
- )
622
- def test_frequency_seasonality_coordinates (test_case ):
623
- """Test that coordinate determination works correctly for different scenarios."""
624
-
625
- model_name = f"season_{ test_case ['name' ].split ('_' )[0 ]} "
626
-
627
- season = FrequencySeasonality (
628
- season_length = test_case ["season_length" ],
629
- n = test_case ["n" ],
630
- name = model_name ,
631
- observed_state_names = test_case ["observed_state_names" ],
632
- )
633
- season .populate_component_properties ()
634
-
635
- # assert parameter shape
636
- assert season .param_info [f"params_{ model_name } " ]["shape" ] == test_case ["expected_shape" ]
637
-
638
- # generate expected state names based on actual model name
639
- expected_state_names = [
640
- f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
641
- ][: test_case ["expected_shape" ][- 1 ]]
642
-
643
- # assert coordinate structure
644
- if len (test_case ["observed_state_names" ]) == 1 :
645
- assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][0 ]
646
- assert season .coords [f"state_{ model_name } " ] == expected_state_names
647
- else :
648
- assert len (season .coords [f"endog_{ model_name } " ]) == test_case ["expected_shape" ][0 ]
649
- assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][1 ]
650
- assert season .coords [f"state_{ model_name } " ] == expected_state_names
651
-
652
-
653
- @pytest .mark .parametrize (
654
- "test_case" ,
655
- [
656
493
{
657
494
"name" : "small_n" ,
658
495
"season_length" : 12 ,
@@ -668,10 +505,9 @@ def test_frequency_seasonality_coordinates(test_case):
668
505
"expected_shape" : (4 , 4 ),
669
506
},
670
507
],
508
+ ids = lambda x : x ["name" ],
671
509
)
672
- def test_frequency_seasonality_edge_cases (test_case ):
673
- """Test edge cases for coordinate determination."""
674
-
510
+ def test_frequency_seasonality_coordinates (test_case ):
675
511
model_name = f"season_{ test_case ['name' ].split ('_' )[0 ]} "
676
512
677
513
season = FrequencySeasonality (
@@ -699,21 +535,11 @@ def test_frequency_seasonality_edge_cases(test_case):
699
535
assert len (season .coords [f"state_{ model_name } " ]) == test_case ["expected_shape" ][1 ]
700
536
assert season .coords [f"state_{ model_name } " ] == expected_state_names
701
537
538
+ # Check coords match the expected shape
539
+ param_shape = season .param_info [f"params_{ model_name } " ]["shape" ]
540
+ state_coords = season .coords [f"state_{ model_name } " ]
541
+ endog_coords = season .coords .get (f"endog_{ model_name } " )
702
542
703
- def test_frequency_seasonality_parameter_consistency ():
704
- """Test that parameter shapes and coordinates are consistent."""
705
-
706
- season = FrequencySeasonality (
707
- season_length = 12 , n = 3 , name = "season" , observed_state_names = ["data1" , "data2" ]
708
- )
709
- season .populate_component_properties ()
710
-
711
- param_shape = season .param_info ["params_season" ]["shape" ]
712
- state_coords = season .coords ["state_season" ]
713
- endog_coords = season .coords ["endog_season" ]
714
-
715
- # for shape (k_endog, n_coefs), we should have:
716
- # - len(endog_coords) == k_endog
717
- # - len(state_coords) == n_coefs
718
- assert len (endog_coords ) == param_shape [0 ]
719
- assert len (state_coords ) == param_shape [1 ]
543
+ assert len (state_coords ) == param_shape [- 1 ]
544
+ if endog_coords :
545
+ assert len (endog_coords ) == param_shape [0 ]
0 commit comments