1
1
import numpy as np
2
+ import pytensor
2
3
import pytest
3
4
4
5
from pytensor import config
6
+ from pytensor .graph .basic import explicit_graph_inputs
5
7
6
8
from pymc_extras .statespace .models import structural as st
7
9
from tests .statespace .models .structural .conftest import _assert_basic_coords_correct
@@ -35,7 +37,7 @@ def random_word(rng):
35
37
x0 [0 ] = 1
36
38
37
39
params = {"season_coefs" : x0 }
38
- if mod . innovations :
40
+ if innovations :
39
41
params ["sigma_season" ] = 0.0
40
42
41
43
x , y = simulate_from_numpy_model (mod , rng , params )
@@ -44,12 +46,175 @@ def random_word(rng):
44
46
assert_pattern_repeats (y , s , atol = ATOL , rtol = RTOL )
45
47
46
48
# Check coords
47
- mod .build (verbose = False )
49
+ mod = mod .build (verbose = False )
48
50
_assert_basic_coords_correct (mod )
49
51
test_slice = slice (1 , None ) if remove_first_state else slice (None )
50
52
assert mod .coords ["season_state" ] == state_names [test_slice ]
51
53
52
54
55
+ @pytest .mark .parametrize (
56
+ "remove_first_state" , [True , False ], ids = ["remove_first_state" , "keep_first_state" ]
57
+ )
58
+ def test_time_seasonality_multiple_observed (rng , remove_first_state ):
59
+ s = 3
60
+ state_names = [f"state_{ i } " for i in range (s )]
61
+ mod = st .TimeSeasonality (
62
+ season_length = s ,
63
+ innovations = True ,
64
+ name = "season" ,
65
+ state_names = state_names ,
66
+ observed_state_names = ["data_1" , "data_2" ],
67
+ remove_first_state = remove_first_state ,
68
+ )
69
+ x0 = np .zeros ((mod .k_endog , mod .k_states // mod .k_endog ), dtype = config .floatX )
70
+
71
+ expected_states = [
72
+ f"state_{ i } [data_{ j } ]" for j in range (1 , 3 ) for i in range (int (remove_first_state ), s )
73
+ ]
74
+ assert mod .state_names == expected_states
75
+ assert mod .shock_names == ["season[data_1]" , "season[data_2]" ]
76
+
77
+ x0 [0 , 0 ] = 1
78
+ x0 [1 , 0 ] = 2.0
79
+
80
+ params = {"season_coefs" : x0 , "sigma_season" : np .array ([0.0 , 0.0 ], dtype = config .floatX )}
81
+
82
+ x , y = simulate_from_numpy_model (mod , rng , params , steps = 123 )
83
+ assert_pattern_repeats (y [:, 0 ], s , atol = ATOL , rtol = RTOL )
84
+ assert_pattern_repeats (y [:, 1 ], s , atol = ATOL , rtol = RTOL )
85
+
86
+ mod = mod .build (verbose = False )
87
+ x0 , * _ , T , Z , R , _ , Q = mod ._unpack_statespace_with_placeholders ()
88
+
89
+ input_vars = explicit_graph_inputs ([x0 , T , Z , R , Q ])
90
+
91
+ fn = pytensor .function (
92
+ inputs = list (input_vars ),
93
+ outputs = [x0 , T , Z , R , Q ],
94
+ mode = "FAST_COMPILE" ,
95
+ )
96
+
97
+ params ["sigma_season" ] = np .array ([0.1 , 0.8 ], dtype = config .floatX )
98
+ x0 , T , Z , R , Q = fn (** params )
99
+
100
+ if remove_first_state :
101
+ expected_x0 = np .array ([1.0 , 0.0 , 2.0 , 0.0 ])
102
+
103
+ expected_T = np .array (
104
+ [
105
+ [- 1.0 , - 1.0 , 0.0 , 0.0 ],
106
+ [1.0 , 0.0 , 0.0 , 0.0 ],
107
+ [0.0 , 0.0 , - 1.0 , - 1.0 ],
108
+ [0.0 , 0.0 , 1.0 , 0.0 ],
109
+ ]
110
+ )
111
+ expected_R = np .array ([[1.0 , 1.0 ], [0.0 , 0.0 ], [1.0 , 1.0 ], [0.0 , 0.0 ]])
112
+ expected_Z = np .array ([[1.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 1.0 , 0.0 ]])
113
+
114
+ else :
115
+ expected_x0 = np .array ([1.0 , 0.0 , 0.0 , 2.0 , 0.0 , 0.0 ])
116
+ expected_T = np .array (
117
+ [
118
+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
119
+ [0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
120
+ [1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
121
+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
122
+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
123
+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ],
124
+ ]
125
+ )
126
+ expected_R = np .array (
127
+ [[1.0 , 1.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ], [1.0 , 1.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]]
128
+ )
129
+ expected_Z = np .array ([[1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ]])
130
+
131
+ expected_Q = np .array ([[0.1 ** 2 , 0.0 ], [0.0 , 0.8 ** 2 ]])
132
+
133
+ for matrix , expected in zip (
134
+ [x0 , T , Z , R , Q ],
135
+ [expected_x0 , expected_T , expected_Z , expected_R , expected_Q ],
136
+ ):
137
+ np .testing .assert_allclose (matrix , expected )
138
+
139
+
140
+ def test_add_two_time_seasonality_different_observed (rng ):
141
+ mod1 = st .TimeSeasonality (
142
+ season_length = 3 ,
143
+ innovations = True ,
144
+ name = "season1" ,
145
+ state_names = [f"state_{ i } " for i in range (3 )],
146
+ observed_state_names = ["data_1" ],
147
+ remove_first_state = False ,
148
+ )
149
+ mod2 = st .TimeSeasonality (
150
+ season_length = 5 ,
151
+ innovations = True ,
152
+ name = "season2" ,
153
+ state_names = [f"state_{ i } " for i in range (5 )],
154
+ observed_state_names = ["data_2" ],
155
+ )
156
+
157
+ mod = (mod1 + mod2 ).build (verbose = False )
158
+
159
+ params = {
160
+ "season1_coefs" : np .array ([1.0 , 0.0 , 0.0 ], dtype = config .floatX ),
161
+ "season2_coefs" : np .array ([3.0 , 0.0 , 0.0 , 0.0 ], dtype = config .floatX ),
162
+ "sigma_season1" : np .array (0.0 , dtype = config .floatX ),
163
+ "sigma_season2" : np .array (0.0 , dtype = config .floatX ),
164
+ "initial_state_cov" : np .eye (mod .k_states , dtype = config .floatX ),
165
+ }
166
+
167
+ x , y = simulate_from_numpy_model (mod , rng , params , steps = 3 * 5 * 5 )
168
+ assert_pattern_repeats (y [:, 0 ], 3 , atol = ATOL , rtol = RTOL )
169
+ assert_pattern_repeats (y [:, 1 ], 5 , atol = ATOL , rtol = RTOL )
170
+
171
+ assert mod .state_names == [
172
+ "state_0[data_1]" ,
173
+ "state_1[data_1]" ,
174
+ "state_2[data_1]" ,
175
+ "state_1[data_2]" ,
176
+ "state_2[data_2]" ,
177
+ "state_3[data_2]" ,
178
+ "state_4[data_2]" ,
179
+ ]
180
+
181
+ assert mod .shock_names == ["season1[data_1]" , "season2[data_2]" ]
182
+
183
+ x0 , * _ , T = mod ._unpack_statespace_with_placeholders ()[:5 ]
184
+ input_vars = explicit_graph_inputs ([x0 , T ])
185
+ fn = pytensor .function (
186
+ inputs = list (input_vars ),
187
+ outputs = [x0 , T ],
188
+ mode = "FAST_COMPILE" ,
189
+ )
190
+
191
+ x0 , T = fn (
192
+ season1_coefs = np .array ([1.0 , 0.0 , 0.0 ], dtype = config .floatX ),
193
+ season2_coefs = np .array ([3.0 , 0.0 , 0.0 , 1.2 ], dtype = config .floatX ),
194
+ )
195
+
196
+ np .testing .assert_allclose (
197
+ np .array ([1.0 , 0.0 , 0.0 , 3.0 , 0.0 , 0.0 , 1.2 ]), x0 , atol = ATOL , rtol = RTOL
198
+ )
199
+
200
+ np .testing .assert_allclose (
201
+ np .array (
202
+ [
203
+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
204
+ [0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
205
+ [1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
206
+ [0.0 , 0.0 , 0.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 ],
207
+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
208
+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ],
209
+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
210
+ ]
211
+ ),
212
+ T ,
213
+ atol = ATOL ,
214
+ rtol = RTOL ,
215
+ )
216
+
217
+
53
218
def get_shift_factor (s ):
54
219
s_str = str (s )
55
220
if "." not in s_str :
0 commit comments