@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
4444 observed_state_names: list[str] | None, default None
4545 List of strings for observed state labels. If None, defaults to ["data"].
4646
47+ share_states: bool, default False
48+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
49+ states, which are observed by all observed states. If False, each observed state has its own set of
50+ latent states. This argument has no effect if `k_endog` is 1.
51+
4752 Notes
4853 -----
4954 A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235240 state_names : list | None = None ,
236241 remove_first_state : bool = True ,
237242 observed_state_names : list [str ] | None = None ,
243+ share_states : bool = False ,
238244 ):
239245 if observed_state_names is None :
240246 observed_state_names = ["data" ]
@@ -261,6 +267,7 @@ def __init__(
261267 )
262268 state_names = state_names .copy ()
263269
270+ self .share_states = share_states
264271 self .innovations = innovations
265272 self .duration = duration
266273 self .remove_first_state = remove_first_state
@@ -281,44 +288,54 @@ def __init__(
281288 super ().__init__ (
282289 name = name ,
283290 k_endog = k_endog ,
284- k_states = k_states * k_endog ,
285- k_posdef = k_posdef * k_endog ,
291+ k_states = k_states if share_states else k_states * k_endog ,
292+ k_posdef = k_posdef if share_states else k_posdef * k_endog ,
286293 observed_state_names = observed_state_names ,
287294 measurement_error = False ,
288295 combine_hidden_states = True ,
289- obs_state_idxs = np .tile (np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), k_endog ),
296+ obs_state_idxs = np .tile (
297+ np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), 1 if share_states else k_endog
298+ ),
290299 )
291300
292301 def populate_component_properties (self ):
293- k_states = self .k_states // self .k_endog
294302 k_endog = self .k_endog
303+ k_endog_effective = 1 if self .share_states else k_endog
304+
305+ k_states = self .k_states // k_endog_effective
306+
307+ if self .share_states :
308+ self .state_names = [
309+ f"{ state_name } [{ self .name } _shared]" for state_name in self .provided_state_names
310+ ]
311+ else :
312+ self .state_names = [
313+ f"{ state_name } [{ endog_name } ]"
314+ for endog_name in self .observed_state_names
315+ for state_name in self .provided_state_names
316+ ]
295317
296- self .state_names = [
297- f"{ state_name } [{ endog_name } ]"
298- for endog_name in self .observed_state_names
299- for state_name in self .provided_state_names
300- ]
301318 self .param_names = [f"params_{ self .name } " ]
302319
303320 self .param_info = {
304321 f"params_{ self .name } " : {
305322 "shape" : (k_states ,) if k_endog == 1 else (k_endog , k_states ),
306323 "constraints" : None ,
307324 "dims" : (f"state_{ self .name } " ,)
308- if k_endog == 1
325+ if k_endog_effective == 1
309326 else (f"endog_{ self .name } " , f"state_{ self .name } " ),
310327 }
311328 }
312329
313330 self .param_dims = {
314331 f"params_{ self .name } " : (f"state_{ self .name } " ,)
315- if k_endog == 1
332+ if k_endog_effective == 1
316333 else (f"endog_{ self .name } " , f"state_{ self .name } " )
317334 }
318335
319336 self .coords = (
320337 {f"state_{ self .name } " : self .provided_state_names }
321- if k_endog == 1
338+ if k_endog_effective == 1
322339 else {
323340 f"endog_{ self .name } " : self .observed_state_names ,
324341 f"state_{ self .name } " : self .provided_state_names ,
@@ -327,21 +344,27 @@ def populate_component_properties(self):
327344
328345 if self .innovations :
329346 self .param_names += [f"sigma_{ self .name } " ]
330- self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
331347 self .param_info [f"sigma_{ self .name } " ] = {
332- "shape" : () if k_endog == 1 else (k_endog ,),
348+ "shape" : () if k_endog_effective == 1 else (k_endog ,),
333349 "constraints" : "Positive" ,
334- "dims" : None if k_endog == 1 else (f"endog_{ self .name } " ,),
350+ "dims" : None if k_endog_effective == 1 else (f"endog_{ self .name } " ,),
335351 }
352+ if self .share_states :
353+ self .shock_names = [f"{ self .name } [shared]" ]
354+ else :
355+ self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
356+
336357 if k_endog > 1 :
337358 self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
338359
339360 def make_symbolic_graph (self ) -> None :
340- k_states = self .k_states // self .k_endog
361+ k_endog = self .k_endog
362+ k_endog_effective = 1 if self .share_states else k_endog
363+ k_states = self .k_states // k_endog_effective
341364 duration = self .duration
365+
342366 k_unique_states = k_states // duration
343- k_posdef = self .k_posdef // self .k_endog
344- k_endog = self .k_endog
367+ k_posdef = self .k_posdef // k_endog_effective
345368
346369 if self .remove_first_state :
347370 # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -373,16 +396,18 @@ def make_symbolic_graph(self) -> None:
373396 T = pt .eye (k_states , k = 1 )
374397 T = pt .set_subtensor (T [- 1 , 0 ], 1 )
375398
376- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
399+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
377400
378401 Z = pt .zeros ((1 , k_states ))[0 , 0 ].set (1 )
379- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
402+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
380403
381404 initial_states = self .make_and_register_variable (
382405 f"params_{ self .name } " ,
383- shape = (k_unique_states ,) if k_endog == 1 else (k_endog , k_unique_states ),
406+ shape = (k_unique_states ,)
407+ if k_endog_effective == 1
408+ else (k_endog_effective , k_unique_states ),
384409 )
385- if k_endog == 1 :
410+ if k_endog_effective == 1 :
386411 self .ssm ["initial_state" , :] = pt .extra_ops .repeat (initial_states , duration , axis = 0 )
387412 else :
388413 self .ssm ["initial_state" , :] = pt .extra_ops .repeat (
@@ -391,11 +416,11 @@ def make_symbolic_graph(self) -> None:
391416
392417 if self .innovations :
393418 R = pt .zeros ((k_states , k_posdef ))[0 , 0 ].set (1.0 )
394- self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog )])
419+ self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog_effective )])
395420 season_sigma = self .make_and_register_variable (
396- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
421+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
397422 )
398- cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog ))
423+ cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog_effective ))
399424 self .ssm [cov_idx ] = season_sigma ** 2
400425
401426
0 commit comments