1
1
import numpy as np
2
2
3
3
from pytensor import tensor as pt
4
+ from pytensor .tensor .slinalg import block_diag
4
5
from scipy import linalg
5
6
6
7
from pymc_extras .statespace .models .structural .core import Component
@@ -96,7 +97,6 @@ class CycleComponent(Component):
96
97
97
98
cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
98
99
cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99
-
100
100
sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
101
101
102
102
ss_mod.build_statespace_graph(data)
@@ -124,13 +124,15 @@ class CycleComponent(Component):
124
124
with pm.Model(coords=ss_mod.coords) as model:
125
125
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
126
126
# Initial states: shape (3, 2) for 3 variables, 2 states each
127
- cycle_init = pm.Normal('business_cycle', dims=('business_cycle_endog', 'business_cycle_state') )
127
+ cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"] )
128
128
129
129
# Dampening factor: scalar (shared across variables)
130
- dampening = pm.Uniform(' business_cycle_dampening_factor', lower=0.8, upper=1.0 )
130
+ dampening = pm.Beta(" business_cycle_dampening_factor", 2, 2 )
131
131
132
132
# Innovation variances: shape (3,) for variable-specific variances
133
- sigma_cycle = pm.HalfNormal('sigma_business_cycle', dims=('business_cycle_endog',))
133
+ sigma_cycle = pm.HalfNormal(
134
+ "sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
135
+ )
134
136
135
137
ss_mod.build_statespace_graph(data)
136
138
idata = pm.sample()
@@ -220,12 +222,8 @@ def make_symbolic_graph(self) -> None:
220
222
if self .k_endog == 1 :
221
223
self .ssm ["transition" , :, :] = T
222
224
else :
223
- # can't make the linalg.block_diag logic work here
224
- # doing it manually for now
225
- for i in range (self .k_endog ):
226
- start_idx = i * 2
227
- end_idx = (i + 1 ) * 2
228
- self .ssm ["transition" , start_idx :end_idx , start_idx :end_idx ] = T
225
+ transition = block_diag (* [T for _ in range (self .k_endog )])
226
+ self .ssm ["transition" ] = pt .specify_shape (transition , (self .k_states , self .k_states ))
229
227
230
228
if self .innovations :
231
229
if self .k_endog == 1 :
@@ -235,16 +233,20 @@ def make_symbolic_graph(self) -> None:
235
233
sigma_cycle = self .make_and_register_variable (
236
234
f"sigma_{ self .name } " , shape = (self .k_endog ,)
237
235
)
238
- # can't make the linalg.block_diag logic work here
239
- # doing it manually for now
240
- for i in range (self .k_endog ):
241
- start_idx = i * 2
242
- end_idx = (i + 1 ) * 2
243
- Q_block = pt .eye (2 ) * sigma_cycle [i ] ** 2
244
- self .ssm ["state_cov" , start_idx :end_idx , start_idx :end_idx ] = Q_block
236
+ state_cov = block_diag (
237
+ * [pt .eye (2 ) * sigma_cycle [i ] ** 2 for i in range (self .k_endog )]
238
+ )
239
+ self .ssm ["state_cov" ] = pt .specify_shape (state_cov , (self .k_states , self .k_states ))
245
240
246
241
def populate_component_properties (self ):
247
- self .state_names = [f"{ self .name } _{ f } " for f in ["Cos" , "Sin" ]]
242
+ if self .k_endog == 1 :
243
+ self .state_names = [f"{ self .name } _{ f } " for f in ["Cos" , "Sin" ]]
244
+ else :
245
+ # For multivariate cycles, create state names for each observed state
246
+ self .state_names = []
247
+ for var_name in self .observed_state_names :
248
+ self .state_names .extend ([f"{ self .name } _{ var_name } _{ f } " for f in ["Cos" , "Sin" ]])
249
+
248
250
self .param_names = [f"{ self .name } " ]
249
251
250
252
if self .k_endog == 1 :
@@ -260,7 +262,7 @@ def populate_component_properties(self):
260
262
else :
261
263
self .param_dims = {self .name : (f"{ self .name } _endog" , f"{ self .name } _state" )}
262
264
self .coords = {
263
- f"{ self .name } _state" : self .state_names ,
265
+ f"{ self .name } _state" : [ f" { self .name } _Cos" , f" { self . name } _Sin" ] ,
264
266
f"{ self .name } _endog" : self .observed_state_names ,
265
267
}
266
268
self .param_info = {
0 commit comments