@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
4444 observed_state_names: list[str] | None, default None
4545 List of strings for observed state labels. If None, defaults to ["data"].
4646
47+ share_states: bool, default False
48+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
49+ states, which are observed by all observed states. If False, each observed state has its own set of
50+ latent states. This argument has no effect if `k_endog` is 1.
51+
4752 Notes
4853 -----
4954 A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235240 state_names : list | None = None ,
236241 remove_first_state : bool = True ,
237242 observed_state_names : list [str ] | None = None ,
243+ share_states : bool = False ,
238244 ):
239245 if observed_state_names is None :
240246 observed_state_names = ["data" ]
@@ -261,6 +267,7 @@ def __init__(
261267 )
262268 state_names = state_names .copy ()
263269
270+ self .share_states = share_states
264271 self .innovations = innovations
265272 self .duration = duration
266273 self .remove_first_state = remove_first_state
@@ -281,44 +288,53 @@ def __init__(
281288 super ().__init__ (
282289 name = name ,
283290 k_endog = k_endog ,
284- k_states = k_states * k_endog ,
285- k_posdef = k_posdef * k_endog ,
291+ k_states = k_states if share_states else k_states * k_endog ,
292+ k_posdef = k_posdef if share_states else k_posdef * k_endog ,
286293 observed_state_names = observed_state_names ,
287294 measurement_error = False ,
288295 combine_hidden_states = True ,
289- obs_state_idxs = np .tile (np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), k_endog ),
296+ obs_state_idxs = np .tile (
297+ np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), 1 if share_states else k_endog
298+ ),
290299 )
291300
292301 def populate_component_properties (self ):
293- k_states = self .k_states // self .k_endog
294302 k_endog = self .k_endog
303+ k_endog_effective = 1 if self .share_states else k_endog
295304
296- self .state_names = [
297- f"{ state_name } [{ endog_name } ]"
298- for endog_name in self .observed_state_names
299- for state_name in self .provided_state_names
300- ]
305+ k_states = self .k_states // k_endog_effective
306+
307+ if self .share_states :
308+ self .state_names = [
309+ f"{ state_name } [{ self .name } _shared]" for state_name in self .provided_state_names
310+ ]
311+ else :
312+ self .state_names = [
313+ f"{ state_name } [{ endog_name } ]"
314+ for endog_name in self .observed_state_names
315+ for state_name in self .provided_state_names
316+ ]
301317 self .param_names = [f"coefs_{ self .name } " ]
302318
303319 self .param_info = {
304320 f"coefs_{ self .name } " : {
305- "shape" : (k_states ,) if k_endog == 1 else (k_endog , k_states ),
321+ "shape" : (k_states ,) if k_endog_effective == 1 else (k_endog_effective , k_states ),
306322 "constraints" : None ,
307323 "dims" : (f"state_{ self .name } " ,)
308- if k_endog == 1
324+ if k_endog_effective == 1
309325 else (f"endog_{ self .name } " , f"state_{ self .name } " ),
310326 }
311327 }
312328
313329 self .param_dims = {
314330 f"coefs_{ self .name } " : (f"state_{ self .name } " ,)
315- if k_endog == 1
331+ if k_endog_effective == 1
316332 else (f"endog_{ self .name } " , f"state_{ self .name } " )
317333 }
318334
319335 self .coords = (
320336 {f"state_{ self .name } " : self .provided_state_names }
321- if k_endog == 1
337+ if k_endog_effective == 1
322338 else {
323339 f"endog_{ self .name } " : self .observed_state_names ,
324340 f"state_{ self .name } " : self .provided_state_names ,
@@ -332,14 +348,19 @@ def populate_component_properties(self):
332348 "constraints" : "Positive" ,
333349 "dims" : None ,
334350 }
335- self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
351+ if self .share_states :
352+ self .shock_names = [f"{ self .name } [shared]" ]
353+ else :
354+ self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
336355
337356 def make_symbolic_graph (self ) -> None :
338- k_states = self .k_states // self .k_endog
357+ k_endog = self .k_endog
358+ k_endog_effective = 1 if self .share_states else k_endog
359+ k_states = self .k_states // k_endog_effective
339360 duration = self .duration
361+
340362 k_unique_states = k_states // duration
341- k_posdef = self .k_posdef // self .k_endog
342- k_endog = self .k_endog
363+ k_posdef = self .k_posdef // k_endog_effective
343364
344365 if self .remove_first_state :
345366 # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -371,16 +392,18 @@ def make_symbolic_graph(self) -> None:
371392 T = pt .eye (k_states , k = 1 )
372393 T = pt .set_subtensor (T [- 1 , 0 ], 1 )
373394
374- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
395+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
375396
376397 Z = pt .zeros ((1 , k_states ))[0 , 0 ].set (1 )
377- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
398+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
378399
379400 initial_states = self .make_and_register_variable (
380401 f"coefs_{ self .name } " ,
381- shape = (k_unique_states ,) if k_endog == 1 else (k_endog , k_unique_states ),
402+ shape = (k_unique_states ,)
403+ if k_endog_effective == 1
404+ else (k_endog_effective , k_unique_states ),
382405 )
383- if k_endog == 1 :
406+ if k_endog_effective == 1 :
384407 self .ssm ["initial_state" , :] = pt .extra_ops .repeat (initial_states , duration , axis = 0 )
385408 else :
386409 self .ssm ["initial_state" , :] = pt .extra_ops .repeat (
@@ -389,11 +412,11 @@ def make_symbolic_graph(self) -> None:
389412
390413 if self .innovations :
391414 R = pt .zeros ((k_states , k_posdef ))[0 , 0 ].set (1.0 )
392- self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog )])
415+ self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog_effective )])
393416 season_sigma = self .make_and_register_variable (
394- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
417+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
395418 )
396- cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog ))
419+ cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog_effective ))
397420 self .ssm [cov_idx ] = season_sigma ** 2
398421
399422
0 commit comments