11import numpy as np
2+ import pytensor
23import pytest
34
45from pytensor import config
6+ from pytensor .graph .basic import explicit_graph_inputs
57
68from pymc_extras .statespace .models import structural as st
79from tests .statespace .models .structural .conftest import _assert_basic_coords_correct
@@ -35,7 +37,7 @@ def random_word(rng):
3537 x0 [0 ] = 1
3638
3739 params = {"season_coefs" : x0 }
38- if mod . innovations :
40+ if innovations :
3941 params ["sigma_season" ] = 0.0
4042
4143 x , y = simulate_from_numpy_model (mod , rng , params )
@@ -44,12 +46,175 @@ def random_word(rng):
4446 assert_pattern_repeats (y , s , atol = ATOL , rtol = RTOL )
4547
4648 # Check coords
47- mod .build (verbose = False )
49+ mod = mod .build (verbose = False )
4850 _assert_basic_coords_correct (mod )
4951 test_slice = slice (1 , None ) if remove_first_state else slice (None )
5052 assert mod .coords ["season_state" ] == state_names [test_slice ]
5153
5254
55+ @pytest .mark .parametrize (
56+ "remove_first_state" , [True , False ], ids = ["remove_first_state" , "keep_first_state" ]
57+ )
58+ def test_time_seasonality_multiple_observed (rng , remove_first_state ):
59+ s = 3
60+ state_names = [f"state_{ i } " for i in range (s )]
61+ mod = st .TimeSeasonality (
62+ season_length = s ,
63+ innovations = True ,
64+ name = "season" ,
65+ state_names = state_names ,
66+ observed_state_names = ["data_1" , "data_2" ],
67+ remove_first_state = remove_first_state ,
68+ )
69+ x0 = np .zeros ((mod .k_endog , mod .k_states // mod .k_endog ), dtype = config .floatX )
70+
71+ expected_states = [
72+ f"state_{ i } [data_{ j } ]" for j in range (1 , 3 ) for i in range (int (remove_first_state ), s )
73+ ]
74+ assert mod .state_names == expected_states
75+ assert mod .shock_names == ["season[data_1]" , "season[data_2]" ]
76+
77+ x0 [0 , 0 ] = 1
78+ x0 [1 , 0 ] = 2.0
79+
80+ params = {"season_coefs" : x0 , "sigma_season" : np .array ([0.0 , 0.0 ], dtype = config .floatX )}
81+
82+ x , y = simulate_from_numpy_model (mod , rng , params , steps = 123 )
83+ assert_pattern_repeats (y [:, 0 ], s , atol = ATOL , rtol = RTOL )
84+ assert_pattern_repeats (y [:, 1 ], s , atol = ATOL , rtol = RTOL )
85+
86+ mod = mod .build (verbose = False )
87+ x0 , * _ , T , Z , R , _ , Q = mod ._unpack_statespace_with_placeholders ()
88+
89+ input_vars = explicit_graph_inputs ([x0 , T , Z , R , Q ])
90+
91+ fn = pytensor .function (
92+ inputs = list (input_vars ),
93+ outputs = [x0 , T , Z , R , Q ],
94+ mode = "FAST_COMPILE" ,
95+ )
96+
97+ params ["sigma_season" ] = np .array ([0.1 , 0.8 ], dtype = config .floatX )
98+ x0 , T , Z , R , Q = fn (** params )
99+
100+ if remove_first_state :
101+ expected_x0 = np .array ([1.0 , 0.0 , 2.0 , 0.0 ])
102+
103+ expected_T = np .array (
104+ [
105+ [- 1.0 , - 1.0 , 0.0 , 0.0 ],
106+ [1.0 , 0.0 , 0.0 , 0.0 ],
107+ [0.0 , 0.0 , - 1.0 , - 1.0 ],
108+ [0.0 , 0.0 , 1.0 , 0.0 ],
109+ ]
110+ )
111+ expected_R = np .array ([[1.0 , 1.0 ], [0.0 , 0.0 ], [1.0 , 1.0 ], [0.0 , 0.0 ]])
112+ expected_Z = np .array ([[1.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 1.0 , 0.0 ]])
113+
114+ else :
115+ expected_x0 = np .array ([1.0 , 0.0 , 0.0 , 2.0 , 0.0 , 0.0 ])
116+ expected_T = np .array (
117+ [
118+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
119+ [0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
120+ [1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
121+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
122+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
123+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ],
124+ ]
125+ )
126+ expected_R = np .array (
127+ [[1.0 , 1.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ], [1.0 , 1.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]]
128+ )
129+ expected_Z = np .array ([[1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ]])
130+
131+ expected_Q = np .array ([[0.1 ** 2 , 0.0 ], [0.0 , 0.8 ** 2 ]])
132+
133+ for matrix , expected in zip (
134+ [x0 , T , Z , R , Q ],
135+ [expected_x0 , expected_T , expected_Z , expected_R , expected_Q ],
136+ ):
137+ np .testing .assert_allclose (matrix , expected )
138+
139+
140+ def test_add_two_time_seasonality_different_observed (rng ):
141+ mod1 = st .TimeSeasonality (
142+ season_length = 3 ,
143+ innovations = True ,
144+ name = "season1" ,
145+ state_names = [f"state_{ i } " for i in range (3 )],
146+ observed_state_names = ["data_1" ],
147+ remove_first_state = False ,
148+ )
149+ mod2 = st .TimeSeasonality (
150+ season_length = 5 ,
151+ innovations = True ,
152+ name = "season2" ,
153+ state_names = [f"state_{ i } " for i in range (5 )],
154+ observed_state_names = ["data_2" ],
155+ )
156+
157+ mod = (mod1 + mod2 ).build (verbose = False )
158+
159+ params = {
160+ "season1_coefs" : np .array ([1.0 , 0.0 , 0.0 ], dtype = config .floatX ),
161+ "season2_coefs" : np .array ([3.0 , 0.0 , 0.0 , 0.0 ], dtype = config .floatX ),
162+ "sigma_season1" : np .array (0.0 , dtype = config .floatX ),
163+ "sigma_season2" : np .array (0.0 , dtype = config .floatX ),
164+ "initial_state_cov" : np .eye (mod .k_states , dtype = config .floatX ),
165+ }
166+
167+ x , y = simulate_from_numpy_model (mod , rng , params , steps = 3 * 5 * 5 )
168+ assert_pattern_repeats (y [:, 0 ], 3 , atol = ATOL , rtol = RTOL )
169+ assert_pattern_repeats (y [:, 1 ], 5 , atol = ATOL , rtol = RTOL )
170+
171+ assert mod .state_names == [
172+ "state_0[data_1]" ,
173+ "state_1[data_1]" ,
174+ "state_2[data_1]" ,
175+ "state_1[data_2]" ,
176+ "state_2[data_2]" ,
177+ "state_3[data_2]" ,
178+ "state_4[data_2]" ,
179+ ]
180+
181+ assert mod .shock_names == ["season1[data_1]" , "season2[data_2]" ]
182+
183+ x0 , * _ , T = mod ._unpack_statespace_with_placeholders ()[:5 ]
184+ input_vars = explicit_graph_inputs ([x0 , T ])
185+ fn = pytensor .function (
186+ inputs = list (input_vars ),
187+ outputs = [x0 , T ],
188+ mode = "FAST_COMPILE" ,
189+ )
190+
191+ x0 , T = fn (
192+ season1_coefs = np .array ([1.0 , 0.0 , 0.0 ], dtype = config .floatX ),
193+ season2_coefs = np .array ([3.0 , 0.0 , 0.0 , 1.2 ], dtype = config .floatX ),
194+ )
195+
196+ np .testing .assert_allclose (
197+ np .array ([1.0 , 0.0 , 0.0 , 3.0 , 0.0 , 0.0 , 1.2 ]), x0 , atol = ATOL , rtol = RTOL
198+ )
199+
200+ np .testing .assert_allclose (
201+ np .array (
202+ [
203+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
204+ [0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
205+ [1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
206+ [0.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 ],
207+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
208+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ],
209+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
210+ ]
211+ ),
212+ T ,
213+ atol = ATOL ,
214+ rtol = RTOL ,
215+ )
216+
217+
53218def get_shift_factor (s ):
54219 s_str = str (s )
55220 if "." not in s_str :
0 commit comments