2
2
3
3
from pytensor import tensor as pt
4
4
from pytensor .tensor .slinalg import block_diag
5
- from scipy import linalg
6
5
7
6
from pymc_extras .statespace .models .structural .core import Component
8
7
from pymc_extras .statespace .models .structural .utils import _frequency_transition_block
@@ -190,22 +189,17 @@ def __init__(
190
189
)
191
190
192
191
def make_symbolic_graph (self ) -> None :
193
- if self .k_endog == 1 :
194
- self .ssm ["design" , 0 , slice (0 , self .k_states , 2 )] = 1
195
- self .ssm ["selection" , :, :] = np .eye (self .k_states )
196
- init_state = self .make_and_register_variable (f"{ self .name } " , shape = (self .k_states ,))
197
-
198
- else :
199
- Z = np .array ([1.0 , 0.0 ]).reshape ((1 , - 1 ))
200
- design_matrix = linalg .block_diag (* [Z for _ in range (self .k_endog )])
201
- self .ssm ["design" , :, :] = pt .as_tensor_variable (design_matrix )
192
+ Z = np .array ([1.0 , 0.0 ]).reshape ((1 , - 1 ))
193
+ design_matrix = block_diag (* [Z for _ in range (self .k_endog )])
194
+ self .ssm ["design" , :, :] = pt .as_tensor_variable (design_matrix )
202
195
203
- R = np .eye (2 ) # 2x2 identity for each cycle component
204
- selection_matrix = linalg .block_diag (* [R for _ in range (self .k_endog )])
205
- self .ssm ["selection" , :, :] = pt .as_tensor_variable (selection_matrix )
206
-
207
- init_state = self .make_and_register_variable (f"{ self .name } " , shape = (self .k_endog , 2 ))
196
+ R = np .eye (2 ) # 2x2 identity for each cycle component
197
+ selection_matrix = block_diag (* [R for _ in range (self .k_endog )])
198
+ self .ssm ["selection" , :, :] = pt .as_tensor_variable (selection_matrix )
208
199
200
+ init_state = self .make_and_register_variable (
201
+ f"{ self .name } " , shape = (self .k_endog , 2 ) if self .k_endog > 1 else (self .k_states ,)
202
+ )
209
203
self .ssm ["initial_state" , :] = init_state .ravel ()
210
204
211
205
if self .estimate_cycle_length :
@@ -219,11 +213,8 @@ def make_symbolic_graph(self) -> None:
219
213
rho = 1
220
214
221
215
T = rho * _frequency_transition_block (lamb , j = 1 )
222
- if self .k_endog == 1 :
223
- self .ssm ["transition" , :, :] = T
224
- else :
225
- transition = block_diag (* [T for _ in range (self .k_endog )])
226
- self .ssm ["transition" ] = pt .specify_shape (transition , (self .k_states , self .k_states ))
216
+ transition = block_diag (* [T for _ in range (self .k_endog )])
217
+ self .ssm ["transition" ] = pt .specify_shape (transition , (self .k_states , self .k_states ))
227
218
228
219
if self .innovations :
229
220
if self .k_endog == 1 :
@@ -239,13 +230,11 @@ def make_symbolic_graph(self) -> None:
239
230
self .ssm ["state_cov" ] = pt .specify_shape (state_cov , (self .k_states , self .k_states ))
240
231
241
232
def populate_component_properties (self ):
242
- if self .k_endog == 1 :
243
- self .state_names = [f"{ self .name } _{ f } " for f in ["Cos" , "Sin" ]]
244
- else :
245
- # For multivariate cycles, create state names for each observed state
246
- self .state_names = []
247
- for var_name in self .observed_state_names :
248
- self .state_names .extend ([f"{ self .name } _{ var_name } _{ f } " for f in ["Cos" , "Sin" ]])
233
+ self .state_names = [
234
+ f"{ self .name } _{ f } [{ var_name } ]" if self .k_endog > 1 else f"{ self .name } _{ f } "
235
+ for var_name in self .observed_state_names
236
+ for f in ["Cos" , "Sin" ]
237
+ ]
249
238
250
239
self .param_names = [f"{ self .name } " ]
251
240
@@ -276,17 +265,17 @@ def populate_component_properties(self):
276
265
if self .estimate_cycle_length :
277
266
self .param_names += [f"{ self .name } _length" ]
278
267
self .param_info [f"{ self .name } _length" ] = {
279
- "shape" : (),
268
+ "shape" : () if self . k_endog == 1 else ( self . k_endog ,) ,
280
269
"constraints" : "Positive, non-zero" ,
281
- "dims" : None ,
270
+ "dims" : None if self . k_endog == 1 else f" { self . name } _endog" ,
282
271
}
283
272
284
273
if self .dampen :
285
274
self .param_names += [f"{ self .name } _dampening_factor" ]
286
275
self .param_info [f"{ self .name } _dampening_factor" ] = {
287
- "shape" : (),
276
+ "shape" : () if self . k_endog == 1 else ( self . k_endog ,) ,
288
277
"constraints" : "0 < x ≤ 1" ,
289
- "dims" : None ,
278
+ "dims" : None if self . k_endog == 1 else f" { self . name } _endog" ,
290
279
}
291
280
292
281
if self .innovations :
0 commit comments