@@ -326,7 +326,7 @@ def test_frequency_seasonality(n, s, rng):
326
326
_assert_basic_coords_correct (mod )
327
327
if n is None :
328
328
n = int (s // 2 )
329
- states = [f"{ f } _season_ { i } " for i in range (n ) for f in ["Cos" , "Sin" ]]
329
+ states = [f"{ f } _ { i } _season " for i in range (n ) for f in ["Cos" , "Sin" ]]
330
330
331
331
# remove last state when model is completely saturated
332
332
if s / n == 2.0 :
@@ -345,19 +345,25 @@ def test_frequency_seasonality_multiple_observed(rng):
345
345
observed_state_names = observed_state_names ,
346
346
)
347
347
expected_state_names = [
348
- "Cos_season_0 [data_1]" ,
349
- "Sin_season_0 [data_1]" ,
350
- "Cos_season_1 [data_1]" ,
351
- "Sin_season_1 [data_1]" ,
352
- "Cos_season_0 [data_2]" ,
353
- "Sin_season_0 [data_2]" ,
354
- "Cos_season_1 [data_2]" ,
355
- "Sin_season_1 [data_2]" ,
348
+ "Cos_0_season [data_1]" ,
349
+ "Sin_0_season [data_1]" ,
350
+ "Cos_1_season [data_1]" ,
351
+ "Sin_1_season [data_1]" ,
352
+ "Cos_0_season [data_2]" ,
353
+ "Sin_0_season [data_2]" ,
354
+ "Cos_1_season [data_2]" ,
355
+ "Sin_1_season [data_2]" ,
356
356
]
357
357
assert mod .state_names == expected_state_names
358
358
assert mod .shock_names == [
359
- "season[data_1]" ,
360
- "season[data_2]" ,
359
+ "Cos_0_season[data_1]" ,
360
+ "Sin_0_season[data_1]" ,
361
+ "Cos_1_season[data_1]" ,
362
+ "Sin_1_season[data_1]" ,
363
+ "Cos_0_season[data_2]" ,
364
+ "Sin_0_season[data_2]" ,
365
+ "Cos_1_season[data_2]" ,
366
+ "Sin_1_season[data_2]" ,
361
367
]
362
368
363
369
x0 = np .zeros ((2 , 3 ), dtype = config .floatX )
@@ -372,9 +378,9 @@ def test_frequency_seasonality_multiple_observed(rng):
372
378
373
379
mod = mod .build (verbose = False )
374
380
assert list (mod .coords ["state_season" ]) == [
375
- "Cos_season_0 " ,
376
- "Sin_season_0 " ,
377
- "Cos_season_1 " ,
381
+ "Cos_0_season " ,
382
+ "Sin_0_season " ,
383
+ "Cos_1_season " ,
378
384
]
379
385
380
386
x0_sym , * _ , T_sym , Z_sym , R_sym , _ , Q_sym = mod ._unpack_statespace_with_placeholders ()
@@ -460,17 +466,21 @@ def test_add_two_frequency_seasonality_different_observed(rng):
460
466
assert_pattern_repeats (y [:, 1 ], 6 , atol = ATOL , rtol = RTOL )
461
467
462
468
assert mod .state_names == [
463
- "Cos_freq1_0 [data_1]" ,
464
- "Sin_freq1_0 [data_1]" ,
465
- "Cos_freq1_1 [data_1]" ,
466
- "Sin_freq1_1 [data_1]" ,
467
- "Cos_freq2_0 [data_2]" ,
468
- "Sin_freq2_0 [data_2]" ,
469
+ "Cos_0_freq1 [data_1]" ,
470
+ "Sin_0_freq1 [data_1]" ,
471
+ "Cos_1_freq1 [data_1]" ,
472
+ "Sin_1_freq1 [data_1]" ,
473
+ "Cos_0_freq2 [data_2]" ,
474
+ "Sin_0_freq2 [data_2]" ,
469
475
]
470
476
471
477
assert mod .shock_names == [
472
- "freq1[data_1]" ,
473
- "freq2[data_2]" ,
478
+ "Cos_0_freq1[data_1]" ,
479
+ "Sin_0_freq1[data_1]" ,
480
+ "Cos_1_freq1[data_1]" ,
481
+ "Sin_1_freq1[data_1]" ,
482
+ "Cos_0_freq2[data_2]" ,
483
+ "Sin_0_freq2[data_2]" ,
474
484
]
475
485
476
486
x0 , * _ , T = mod ._unpack_statespace_with_placeholders ()[:5 ]
@@ -627,7 +637,7 @@ def test_frequency_seasonality_coordinates(test_case):
627
637
628
638
# generate expected state names based on actual model name
629
639
expected_state_names = [
630
- f"{ f } _{ model_name } _{ i } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
640
+ f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
631
641
][: test_case ["expected_shape" ][- 1 ]]
632
642
633
643
# assert coordinate structure
@@ -677,7 +687,7 @@ def test_frequency_seasonality_edge_cases(test_case):
677
687
678
688
# generate expected state names based on actual model name
679
689
expected_state_names = [
680
- f"{ f } _{ model_name } _{ i } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
690
+ f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
681
691
][: test_case ["expected_shape" ][- 1 ]]
682
692
683
693
# assert coordinate structure
0 commit comments