@@ -449,6 +449,11 @@ class FrequencySeasonality(Component):
449449 observed_state_names: list[str] | None, default None
450450 List of strings for observed state labels. If None, defaults to ["data"].
451451
452+ share_states: bool, default False
453+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
454+ states, which are observed by all observed states. If False, each observed state has its own set of
455+ latent states. This argument has no effect if `k_endog` is 1.
456+
452457 Notes
453458 -----
454459 A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -480,15 +485,17 @@ class FrequencySeasonality(Component):
480485
481486 def __init__ (
482487 self ,
483- season_length ,
484- n = None ,
485- name = None ,
486- innovations = True ,
488+ season_length : int ,
489+ n : int | None = None ,
490+ name : str | None = None ,
491+ innovations : bool = True ,
487492 observed_state_names : list [str ] | None = None ,
493+ share_states : bool = False ,
488494 ):
489495 if observed_state_names is None :
490496 observed_state_names = ["data" ]
491497
498+ self .share_states = share_states
492499 k_endog = len (observed_state_names )
493500
494501 if n is None :
@@ -504,18 +511,20 @@ def __init__(
504511 # If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
505512 # get a parameter assigned to it and should just be fixed to zero.
506513 # Test this way (rather than n == s // 2) to catch cases when n is non-integer.
507- self .last_state_not_identified = self .season_length / self .n == 2.0
514+ self .last_state_not_identified = ( self .season_length / self .n ) == 2.0
508515 self .n_coefs = k_states - int (self .last_state_not_identified )
509516
510517 obs_state_idx = np .zeros (k_states )
511518 obs_state_idx [slice (0 , k_states , 2 )] = 1
512- obs_state_idx = np .tile (obs_state_idx , k_endog )
519+ obs_state_idx = np .tile (obs_state_idx , 1 if share_states else k_endog )
513520
514521 super ().__init__ (
515522 name = name ,
516523 k_endog = k_endog ,
517- k_states = k_states * k_endog ,
518- k_posdef = k_states * int (self .innovations ) * k_endog ,
524+ k_states = k_states if share_states else k_states * k_endog ,
525+ k_posdef = k_states * int (self .innovations )
526+ if share_states
527+ else k_states * int (self .innovations ) * k_endog ,
519528 observed_state_names = observed_state_names ,
520529 measurement_error = False ,
521530 combine_hidden_states = True ,
@@ -524,13 +533,15 @@ def __init__(
524533
525534 def make_symbolic_graph (self ) -> None :
526535 k_endog = self .k_endog
527- k_states = self .k_states // k_endog
528- k_posdef = self .k_posdef // k_endog
536+ k_endog_effective = 1 if self .share_states else k_endog
537+
538+ k_states = self .k_states // k_endog_effective
539+ k_posdef = self .k_posdef // k_endog_effective
529540 n_coefs = self .n_coefs
530541
531542 Z = pt .zeros ((1 , k_states ))[0 , slice (0 , k_states , 2 )].set (1.0 )
532543
533- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
544+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
534545
535546 init_state = self .make_and_register_variable (
536547 f"params_{ self .name } " , shape = (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs )
@@ -539,7 +550,7 @@ def make_symbolic_graph(self) -> None:
539550 init_state_idx = np .concatenate (
540551 [
541552 np .arange (k_states * i , (i + 1 ) * k_states , dtype = int )[:n_coefs ]
542- for i in range (k_endog )
553+ for i in range (k_endog_effective )
543554 ],
544555 axis = 0 ,
545556 )
@@ -548,11 +559,11 @@ def make_symbolic_graph(self) -> None:
548559
549560 T_mats = [_frequency_transition_block (self .season_length , j + 1 ) for j in range (self .n )]
550561 T = pt .linalg .block_diag (* T_mats )
551- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
562+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
552563
553564 if self .innovations :
554565 sigma_season = self .make_and_register_variable (
555- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
566+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
556567 )
557568 self .ssm ["selection" , :, :] = pt .eye (self .k_states )
558569 self .ssm ["state_cov" , :, :] = pt .eye (self .k_posdef ) * pt .repeat (
@@ -561,35 +572,35 @@ def make_symbolic_graph(self) -> None:
561572
562573 def populate_component_properties (self ):
563574 k_endog = self .k_endog
575+ k_endog_effective = 1 if self .share_states else k_endog
564576 n_coefs = self .n_coefs
565577
566- self .state_names = [
567- f"{ f } _{ i } _{ self .name } [{ obs_state_name } ]"
568- for obs_state_name in self .observed_state_names
569- for i in range (self .n )
570- for f in ["Cos" , "Sin" ]
571- ]
572- # determine which state names correspond to parameters
573- # all endog variables use same state structure, so we just need
574- # the first n_coefs state names (which may be less than total if saturated)
575- param_state_names = [f"{ f } _{ i } _{ self .name } " for i in range (self .n ) for f in ["Cos" , "Sin" ]][
576- :n_coefs
577- ]
578+ base_names = [f"{ f } _{ i } _{ self .name } " for i in range (self .n ) for f in ["Cos" , "Sin" ]]
578579
579- self .param_names = [f"params_{ self .name } " ]
580+ if self .share_states :
581+ self .state_names = [f"{ name } [shared]" for name in base_names ]
582+ else :
583+ self .state_names = [
584+ f"{ name } [{ obs_state_name } ]"
585+ for obs_state_name in self .observed_state_names
586+ for name in base_names
587+ ]
580588
589+ # Trim state names if the model is saturated
590+ param_state_names = base_names [:n_coefs ]
591+
592+ self .param_names = [f"params_{ self .name } " ]
581593 self .param_dims = {
582594 f"params_{ self .name } " : (f"state_{ self .name } " ,)
583- if k_endog == 1
595+ if k_endog_effective == 1
584596 else (f"endog_{ self .name } " , f"state_{ self .name } " )
585597 }
586-
587598 self .param_info = {
588599 f"params_{ self .name } " : {
589- "shape" : (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs ),
600+ "shape" : (n_coefs ,) if k_endog_effective == 1 else (k_endog_effective , n_coefs ),
590601 "constraints" : None ,
591602 "dims" : (f"state_{ self .name } " ,)
592- if k_endog == 1
603+ if k_endog_effective == 1
593604 else (f"endog_{ self .name } " , f"state_{ self .name } " ),
594605 }
595606 }
@@ -607,9 +618,9 @@ def populate_component_properties(self):
607618 self .param_names += [f"sigma_{ self .name } " ]
608619 self .shock_names = self .state_names .copy ()
609620 self .param_info [f"sigma_{ self .name } " ] = {
610- "shape" : () if k_endog == 1 else (k_endog , ),
621+ "shape" : () if k_endog_effective == 1 else (k_endog_effective , n_coefs ),
611622 "constraints" : "Positive" ,
612- "dims" : None if k_endog == 1 else (f"endog_{ self .name } " ,),
623+ "dims" : None if k_endog_effective == 1 else (f"endog_{ self .name } " ,),
613624 }
614- if k_endog > 1 :
625+ if k_endog_effective > 1 :
615626 self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
0 commit comments