@@ -326,7 +326,7 @@ def test_frequency_seasonality(n, s, rng):
326326 _assert_basic_coords_correct (mod )
327327 if n is None :
328328 n = int (s // 2 )
329- states = [f"{ f } _season_ { i } " for i in range (n ) for f in ["Cos" , "Sin" ]]
329+ states = [f"{ f } _ { i } _season " for i in range (n ) for f in ["Cos" , "Sin" ]]
330330
331331 # remove last state when model is completely saturated
332332 if s / n == 2.0 :
@@ -345,19 +345,25 @@ def test_frequency_seasonality_multiple_observed(rng):
345345 observed_state_names = observed_state_names ,
346346 )
347347 expected_state_names = [
348- "Cos_season_0 [data_1]" ,
349- "Sin_season_0 [data_1]" ,
350- "Cos_season_1 [data_1]" ,
351- "Sin_season_1 [data_1]" ,
352- "Cos_season_0 [data_2]" ,
353- "Sin_season_0 [data_2]" ,
354- "Cos_season_1 [data_2]" ,
355- "Sin_season_1 [data_2]" ,
348+ "Cos_0_season [data_1]" ,
349+ "Sin_0_season [data_1]" ,
350+ "Cos_1_season [data_1]" ,
351+ "Sin_1_season [data_1]" ,
352+ "Cos_0_season [data_2]" ,
353+ "Sin_0_season [data_2]" ,
354+ "Cos_1_season [data_2]" ,
355+ "Sin_1_season [data_2]" ,
356356 ]
357357 assert mod .state_names == expected_state_names
358358 assert mod .shock_names == [
359- "season[data_1]" ,
360- "season[data_2]" ,
359+ "Cos_0_season[data_1]" ,
360+ "Sin_0_season[data_1]" ,
361+ "Cos_1_season[data_1]" ,
362+ "Sin_1_season[data_1]" ,
363+ "Cos_0_season[data_2]" ,
364+ "Sin_0_season[data_2]" ,
365+ "Cos_1_season[data_2]" ,
366+ "Sin_1_season[data_2]" ,
361367 ]
362368
363369 x0 = np .zeros ((2 , 3 ), dtype = config .floatX )
@@ -372,9 +378,9 @@ def test_frequency_seasonality_multiple_observed(rng):
372378
373379 mod = mod .build (verbose = False )
374380 assert list (mod .coords ["state_season" ]) == [
375- "Cos_season_0 " ,
376- "Sin_season_0 " ,
377- "Cos_season_1 " ,
381+ "Cos_0_season " ,
382+ "Sin_0_season " ,
383+ "Cos_1_season " ,
378384 ]
379385
380386 x0_sym , * _ , T_sym , Z_sym , R_sym , _ , Q_sym = mod ._unpack_statespace_with_placeholders ()
@@ -460,17 +466,21 @@ def test_add_two_frequency_seasonality_different_observed(rng):
460466 assert_pattern_repeats (y [:, 1 ], 6 , atol = ATOL , rtol = RTOL )
461467
462468 assert mod .state_names == [
463- "Cos_freq1_0 [data_1]" ,
464- "Sin_freq1_0 [data_1]" ,
465- "Cos_freq1_1 [data_1]" ,
466- "Sin_freq1_1 [data_1]" ,
467- "Cos_freq2_0 [data_2]" ,
468- "Sin_freq2_0 [data_2]" ,
469+ "Cos_0_freq1 [data_1]" ,
470+ "Sin_0_freq1 [data_1]" ,
471+ "Cos_1_freq1 [data_1]" ,
472+ "Sin_1_freq1 [data_1]" ,
473+ "Cos_0_freq2 [data_2]" ,
474+ "Sin_0_freq2 [data_2]" ,
469475 ]
470476
471477 assert mod .shock_names == [
472- "freq1[data_1]" ,
473- "freq2[data_2]" ,
478+ "Cos_0_freq1[data_1]" ,
479+ "Sin_0_freq1[data_1]" ,
480+ "Cos_1_freq1[data_1]" ,
481+ "Sin_1_freq1[data_1]" ,
482+ "Cos_0_freq2[data_2]" ,
483+ "Sin_0_freq2[data_2]" ,
474484 ]
475485
476486 x0 , * _ , T = mod ._unpack_statespace_with_placeholders ()[:5 ]
@@ -627,7 +637,7 @@ def test_frequency_seasonality_coordinates(test_case):
627637
628638 # generate expected state names based on actual model name
629639 expected_state_names = [
630- f"{ f } _{ model_name } _{ i } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
640+ f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
631641 ][: test_case ["expected_shape" ][- 1 ]]
632642
633643 # assert coordinate structure
@@ -677,7 +687,7 @@ def test_frequency_seasonality_edge_cases(test_case):
677687
678688 # generate expected state names based on actual model name
679689 expected_state_names = [
680- f"{ f } _{ model_name } _{ i } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
690+ f"{ f } _{ i } _{ model_name } " for i in range (test_case ["n" ]) for f in ["Cos" , "Sin" ]
681691 ][: test_case ["expected_shape" ][- 1 ]]
682692
683693 # assert coordinate structure
0 commit comments