11import  numpy  as  np 
22
33from  pytensor  import  tensor  as  pt 
4+ from  pytensor .tensor .slinalg  import  block_diag 
45from  scipy  import  linalg 
56
67from  pymc_extras .statespace .models .structural .core  import  Component 
@@ -96,7 +97,6 @@ class CycleComponent(Component):
9697
9798            cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"]) 
9899            cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12) 
99- 
100100            sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1) 
101101
102102            ss_mod.build_statespace_graph(data) 
@@ -124,13 +124,15 @@ class CycleComponent(Component):
124124        with pm.Model(coords=ss_mod.coords) as model: 
125125            P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"]) 
126126            # Initial states: shape (3, 2) for 3 variables, 2 states each 
127-             cycle_init = pm.Normal('business_cycle', dims=('business_cycle_endog', 'business_cycle_state') ) 
127+             cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"] ) 
128128
129129            # Dampening factor: scalar (shared across variables) 
130-             dampening = pm.Uniform(' business_cycle_dampening_factor', lower=0.8, upper=1.0 ) 
130+             dampening = pm.Beta(" business_cycle_dampening_factor", 2, 2 ) 
131131
132132            # Innovation variances: shape (3,) for variable-specific variances 
133-             sigma_cycle = pm.HalfNormal('sigma_business_cycle', dims=('business_cycle_endog',)) 
133+             sigma_cycle = pm.HalfNormal( 
134+                 "sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"] 
135+             ) 
134136
135137            ss_mod.build_statespace_graph(data) 
136138            idata = pm.sample() 
@@ -220,12 +222,8 @@ def make_symbolic_graph(self) -> None:
220222        if  self .k_endog  ==  1 :
221223            self .ssm ["transition" , :, :] =  T 
222224        else :
223-             # can't make the linalg.block_diag logic work here 
224-             # doing it manually for now 
225-             for  i  in  range (self .k_endog ):
226-                 start_idx  =  i  *  2 
227-                 end_idx  =  (i  +  1 ) *  2 
228-                 self .ssm ["transition" , start_idx :end_idx , start_idx :end_idx ] =  T 
225+             transition  =  block_diag (* [T  for  _  in  range (self .k_endog )])
226+             self .ssm ["transition" ] =  pt .specify_shape (transition , (self .k_states , self .k_states ))
229227
230228        if  self .innovations :
231229            if  self .k_endog  ==  1 :
@@ -235,16 +233,20 @@ def make_symbolic_graph(self) -> None:
235233                sigma_cycle  =  self .make_and_register_variable (
236234                    f"sigma_{ self .name }  , shape = (self .k_endog ,)
237235                )
238-                 # can't make the linalg.block_diag logic work here 
239-                 # doing it manually for now 
240-                 for  i  in  range (self .k_endog ):
241-                     start_idx  =  i  *  2 
242-                     end_idx  =  (i  +  1 ) *  2 
243-                     Q_block  =  pt .eye (2 ) *  sigma_cycle [i ] **  2 
244-                     self .ssm ["state_cov" , start_idx :end_idx , start_idx :end_idx ] =  Q_block 
236+                 state_cov  =  block_diag (
237+                     * [pt .eye (2 ) *  sigma_cycle [i ] **  2  for  i  in  range (self .k_endog )]
238+                 )
239+                 self .ssm ["state_cov" ] =  pt .specify_shape (state_cov , (self .k_states , self .k_states ))
245240
246241    def  populate_component_properties (self ):
247-         self .state_names  =  [f"{ self .name } { f }   for  f  in  ["Cos" , "Sin" ]]
242+         if  self .k_endog  ==  1 :
243+             self .state_names  =  [f"{ self .name } { f }   for  f  in  ["Cos" , "Sin" ]]
244+         else :
245+             # For multivariate cycles, create state names for each observed state 
246+             self .state_names  =  []
247+             for  var_name  in  self .observed_state_names :
248+                 self .state_names .extend ([f"{ self .name } { var_name } { f }   for  f  in  ["Cos" , "Sin" ]])
249+ 
248250        self .param_names  =  [f"{ self .name }  ]
249251
250252        if  self .k_endog  ==  1 :
@@ -260,7 +262,7 @@ def populate_component_properties(self):
260262        else :
261263            self .param_dims  =  {self .name : (f"{ self .name }  , f"{ self .name }  )}
262264            self .coords  =  {
263-                 f"{ self .name }  : self .state_names ,
265+                 f"{ self .name }  : [ f" { self .name } _Cos" ,  f" { self . name } _Sin" ] ,
264266                f"{ self .name }  : self .observed_state_names ,
265267            }
266268            self .param_info  =  {
0 commit comments