@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
44
44
observed_state_names: list[str] | None, default None
45
45
List of strings for observed state labels. If None, defaults to ["data"].
46
46
47
+ share_states: bool, default False
48
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
49
+ states, which are observed by all observed states. If False, each observed state has its own set of
50
+ latent states. This argument has no effect if `k_endog` is 1.
51
+
47
52
Notes
48
53
-----
49
54
A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235
240
state_names : list | None = None ,
236
241
remove_first_state : bool = True ,
237
242
observed_state_names : list [str ] | None = None ,
243
+ share_states : bool = False ,
238
244
):
239
245
if observed_state_names is None :
240
246
observed_state_names = ["data" ]
@@ -261,6 +267,7 @@ def __init__(
261
267
)
262
268
state_names = state_names .copy ()
263
269
270
+ self .share_states = share_states
264
271
self .innovations = innovations
265
272
self .duration = duration
266
273
self .remove_first_state = remove_first_state
@@ -281,44 +288,54 @@ def __init__(
281
288
super ().__init__ (
282
289
name = name ,
283
290
k_endog = k_endog ,
284
- k_states = k_states * k_endog ,
285
- k_posdef = k_posdef * k_endog ,
291
+ k_states = k_states if share_states else k_states * k_endog ,
292
+ k_posdef = k_posdef if share_states else k_posdef * k_endog ,
286
293
observed_state_names = observed_state_names ,
287
294
measurement_error = False ,
288
295
combine_hidden_states = True ,
289
- obs_state_idxs = np .tile (np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), k_endog ),
296
+ obs_state_idxs = np .tile (
297
+ np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), 1 if share_states else k_endog
298
+ ),
290
299
)
291
300
292
301
def populate_component_properties (self ):
293
- k_states = self .k_states // self .k_endog
294
302
k_endog = self .k_endog
303
+ k_endog_effective = 1 if self .share_states else k_endog
304
+
305
+ k_states = self .k_states // k_endog_effective
306
+
307
+ if self .share_states :
308
+ self .state_names = [
309
+ f"{ state_name } [{ self .name } _shared]" for state_name in self .provided_state_names
310
+ ]
311
+ else :
312
+ self .state_names = [
313
+ f"{ state_name } [{ endog_name } ]"
314
+ for endog_name in self .observed_state_names
315
+ for state_name in self .provided_state_names
316
+ ]
295
317
296
- self .state_names = [
297
- f"{ state_name } [{ endog_name } ]"
298
- for endog_name in self .observed_state_names
299
- for state_name in self .provided_state_names
300
- ]
301
318
self .param_names = [f"params_{ self .name } " ]
302
319
303
320
self .param_info = {
304
321
f"params_{ self .name } " : {
305
322
"shape" : (k_states ,) if k_endog == 1 else (k_endog , k_states ),
306
323
"constraints" : None ,
307
324
"dims" : (f"state_{ self .name } " ,)
308
- if k_endog == 1
325
+ if k_endog_effective == 1
309
326
else (f"endog_{ self .name } " , f"state_{ self .name } " ),
310
327
}
311
328
}
312
329
313
330
self .param_dims = {
314
331
f"params_{ self .name } " : (f"state_{ self .name } " ,)
315
- if k_endog == 1
332
+ if k_endog_effective == 1
316
333
else (f"endog_{ self .name } " , f"state_{ self .name } " )
317
334
}
318
335
319
336
self .coords = (
320
337
{f"state_{ self .name } " : self .provided_state_names }
321
- if k_endog == 1
338
+ if k_endog_effective == 1
322
339
else {
323
340
f"endog_{ self .name } " : self .observed_state_names ,
324
341
f"state_{ self .name } " : self .provided_state_names ,
@@ -327,21 +344,27 @@ def populate_component_properties(self):
327
344
328
345
if self .innovations :
329
346
self .param_names += [f"sigma_{ self .name } " ]
330
- self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
331
347
self .param_info [f"sigma_{ self .name } " ] = {
332
- "shape" : () if k_endog == 1 else (k_endog ,),
348
+ "shape" : () if k_endog_effective == 1 else (k_endog ,),
333
349
"constraints" : "Positive" ,
334
- "dims" : None if k_endog == 1 else (f"endog_{ self .name } " ,),
350
+ "dims" : None if k_endog_effective == 1 else (f"endog_{ self .name } " ,),
335
351
}
352
+ if self .share_states :
353
+ self .shock_names = [f"{ self .name } [shared]" ]
354
+ else :
355
+ self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
356
+
336
357
if k_endog > 1 :
337
358
self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
338
359
339
360
def make_symbolic_graph (self ) -> None :
340
- k_states = self .k_states // self .k_endog
361
+ k_endog = self .k_endog
362
+ k_endog_effective = 1 if self .share_states else k_endog
363
+ k_states = self .k_states // k_endog_effective
341
364
duration = self .duration
365
+
342
366
k_unique_states = k_states // duration
343
- k_posdef = self .k_posdef // self .k_endog
344
- k_endog = self .k_endog
367
+ k_posdef = self .k_posdef // k_endog_effective
345
368
346
369
if self .remove_first_state :
347
370
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -373,16 +396,18 @@ def make_symbolic_graph(self) -> None:
373
396
T = pt .eye (k_states , k = 1 )
374
397
T = pt .set_subtensor (T [- 1 , 0 ], 1 )
375
398
376
- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
399
+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
377
400
378
401
Z = pt .zeros ((1 , k_states ))[0 , 0 ].set (1 )
379
- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
402
+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
380
403
381
404
initial_states = self .make_and_register_variable (
382
405
f"params_{ self .name } " ,
383
- shape = (k_unique_states ,) if k_endog == 1 else (k_endog , k_unique_states ),
406
+ shape = (k_unique_states ,)
407
+ if k_endog_effective == 1
408
+ else (k_endog_effective , k_unique_states ),
384
409
)
385
- if k_endog == 1 :
410
+ if k_endog_effective == 1 :
386
411
self .ssm ["initial_state" , :] = pt .extra_ops .repeat (initial_states , duration , axis = 0 )
387
412
else :
388
413
self .ssm ["initial_state" , :] = pt .extra_ops .repeat (
@@ -391,11 +416,11 @@ def make_symbolic_graph(self) -> None:
391
416
392
417
if self .innovations :
393
418
R = pt .zeros ((k_states , k_posdef ))[0 , 0 ].set (1.0 )
394
- self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog )])
419
+ self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog_effective )])
395
420
season_sigma = self .make_and_register_variable (
396
- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
421
+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
397
422
)
398
- cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog ))
423
+ cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog_effective ))
399
424
self .ssm [cov_idx ] = season_sigma ** 2
400
425
401
426
0 commit comments