@@ -29,8 +29,9 @@ class BayesianDynamicFactor(PyMCStateSpace):
29
29
k_factors : int
30
30
Number of latent factors.
31
31
32
- factor_order : int
33
- Order of the VAR process for the latent factors.
32
+ factor_order : int or Sequence[int]
33
+ Order of the VAR process for the latent factors. If an integer is provided, the same order is used for all factors.
34
+ If a sequence of integers is provided, it specifies the order for each factor individually.
34
35
35
36
k_endog : int, optional
36
37
Number of observed time series. If not provided, the number of observed series will be inferred from `endog_names`.
@@ -156,7 +157,7 @@ class BayesianDynamicFactor(PyMCStateSpace):
156
157
def __init__ (
157
158
self ,
158
159
k_factors : int ,
159
- factor_order : int ,
160
+ factor_order : int | Sequence [ int ] ,
160
161
k_endog : int | None = None ,
161
162
endog_names : Sequence [str ] | None = None ,
162
163
exog : np .ndarray | None = None ,
@@ -182,6 +183,17 @@ def __init__(
182
183
if exog is not None :
183
184
raise NotImplementedError ("Exogenous variables (exog) are not yet implemented." )
184
185
186
+ # Normalize factor_order to a list of length k_factors
187
+ if isinstance (factor_order , int ):
188
+ factor_order = [factor_order ] * k_factors
189
+ elif isinstance (factor_order , Sequence ):
190
+ if len (factor_order ) != k_factors :
191
+ raise ValueError (
192
+ f"factor_order must have length { k_factors } when given as a sequence."
193
+ )
194
+ else :
195
+ raise TypeError ("factor_order must be either an int or a sequence of ints." )
196
+
185
197
self .endog_names = endog_names
186
198
self .k_endog = k_endog
187
199
self .k_factors = k_factors
@@ -195,8 +207,12 @@ def __init__(
195
207
# Determine the dimension for the latent factor states.
196
208
# For static factors, one might use k_factors.
197
209
# For dynamic factors with lags, the state might include current factors and past lags.
198
- # TODO: what if we want different factor orders for different factors? (follow suggestions in GitHub)
199
- k_factor_states = k_factors * factor_order
210
+ # If factor_order is 0, we treat the factor as static (no dynamics),
211
+ # but it is still included in the state vector with one state per factor.
212
+ # Factor_ar paramter will not exist in this case.
213
+ k_factor_states = sum (max (order , 1 ) for order in self .factor_order )
214
+
215
+ self ._max_order = max (self .factor_order )
200
216
201
217
# Determine the dimension for the error component.
202
218
# If error_order > 0 then we add additional states for error dynamics, otherwise white noise error.
@@ -234,7 +250,7 @@ def param_names(self):
234
250
]
235
251
236
252
# Handle cases where parameters should be excluded based on model settings
237
- if self . factor_order == 0 :
253
+ if all ( order == 0 for order in self . factor_order ) :
238
254
names .remove ("factor_ar" )
239
255
if self .error_order == 0 :
240
256
names .remove ("error_ar" )
@@ -262,7 +278,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
262
278
"constraints" : None ,
263
279
},
264
280
"factor_ar" : {
265
- "shape" : (self .k_factors , self .factor_order ),
281
+ "shape" : (self .k_factors , self ._max_order ),
266
282
"constraints" : None ,
267
283
},
268
284
"factor_sigma" : {
@@ -283,7 +299,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
283
299
},
284
300
"sigma_obs" : {
285
301
"shape" : (self .k_endog ,),
286
- "constraints" : "Positive Semi-definite " ,
302
+ "constraints" : "Positive" ,
287
303
},
288
304
}
289
305
@@ -301,9 +317,9 @@ def state_names(self) -> list[str]:
301
317
names = []
302
318
# TODO adjust notation by looking at the VARMAX implementation
303
319
# Factor states
304
- for i in range (self .k_factors ):
305
- for lag in range (self . factor_order ):
306
- names .append (f"L { lag } . factor_{ i + 1 } " )
320
+ for i , order in enumerate (self .factor_order ):
321
+ for j in range (max ( order , 1 ) ):
322
+ names .append (f"factor_{ i } _ { j } " )
307
323
308
324
# Idiosyncratic error states
309
325
if self .error_order > 0 :
@@ -327,8 +343,8 @@ def coords(self) -> dict[str, Sequence]:
327
343
coords [FACTOR_DIM ] = [f"factor_{ i + 1 } " for i in range (self .k_factors )]
328
344
329
345
# AR parameter dimensions - add if needed
330
- if self .factor_order > 0 :
331
- coords [AR_PARAM_DIM ] = list (range (1 , self .factor_order + 1 ))
346
+ if self ._max_order > 0 :
347
+ coords [AR_PARAM_DIM ] = list (range (1 , self ._max_order + 1 ))
332
348
333
349
# If error_order > 0
334
350
if self .error_order > 0 :
@@ -359,19 +375,21 @@ def param_dims(self):
359
375
"factor_loadings" : (OBS_STATE_DIM , FACTOR_DIM ),
360
376
"factor_sigma" : (FACTOR_DIM ,),
361
377
}
362
-
363
- if self .factor_order > 0 :
378
+ if self ._max_order > 0 :
364
379
coord_map ["factor_ar" ] = (FACTOR_DIM , AR_PARAM_DIM )
365
380
366
381
if self .error_order > 0 :
367
382
coord_map ["error_ar" ] = (OBS_STATE_DIM , ERROR_AR_PARAM_DIM )
368
383
369
384
if self .error_cov_type in ["scalar" ]:
370
385
coord_map ["error_sigma" ] = ()
386
+
371
387
elif self .error_cov_type in ["diagonal" ]:
372
388
coord_map ["error_sigma" ] = (OBS_STATE_DIM ,)
389
+
373
390
if self .error_cov_type == "unstructured" :
374
391
coord_map ["error_sigma" ] = (OBS_STATE_DIM , OBS_STATE_AUX_DIM )
392
+
375
393
if self .measurement_error :
376
394
coord_map ["sigma_obs" ] = (OBS_STATE_DIM ,)
377
395
return coord_map
@@ -396,12 +414,13 @@ def make_symbolic_graph(self):
396
414
397
415
self .ssm ["design" , :, :] = 0.0
398
416
399
- for j in range (self .k_factors ):
400
- col_idx = j * self .factor_order
401
- self .ssm ["design" , :, col_idx ] = factor_loadings [:, j ]
417
+ factor_col = 0
418
+ for i , order in enumerate (self .factor_order ):
419
+ self .ssm ["design" , :, factor_col ] = factor_loadings [:, i ]
420
+ factor_col += max (order , 1 )
402
421
403
422
for i in range (self .k_endog ):
404
- col_idx = self . k_factors * self .factor_order + i * self .error_order
423
+ col_idx = sum ( max ( order , 1 ) for order in self .factor_order ) + i * self .error_order
405
424
self .ssm ["design" , i , col_idx ] = 1.0
406
425
407
426
# Transition matrix
@@ -415,11 +434,20 @@ def build_ar_block_matrix(ar_coeffs):
415
434
416
435
transition_blocks = []
417
436
418
- factor_ar = self .make_and_register_variable (
419
- "factor_ar" , shape = (self .k_factors , self .factor_order ), dtype = floatX
420
- )
421
- for j in range (self .k_factors ):
422
- transition_blocks .append (build_ar_block_matrix (factor_ar [j ]))
437
+ if self ._max_order > 0 :
438
+ factor_ar = self .make_and_register_variable (
439
+ "factor_ar" , shape = (self .k_factors , self ._max_order ), dtype = floatX
440
+ )
441
+ for j in range (self .k_factors ):
442
+ order = self .factor_order [j ]
443
+ if order == 0 :
444
+ # For order=0, just add a 1x1 zero matrix (static factor)
445
+ transition_blocks .append (pt .zeros ((1 , 1 ), dtype = floatX ))
446
+ else :
447
+ transition_blocks .append (build_ar_block_matrix (factor_ar [j ][:order ]))
448
+ else :
449
+ # If no factor dynamics, just add a zero matrix
450
+ transition_blocks .append (pt .zeros ((self .k_factors , self .k_factors ), dtype = floatX ))
423
451
424
452
if self .error_order > 0 :
425
453
error_ar = self .make_and_register_variable (
@@ -434,12 +462,13 @@ def build_ar_block_matrix(ar_coeffs):
434
462
# Selection matrix
435
463
self .ssm ["selection" , :, :] = 0.0
436
464
437
- for i in range (self .k_factors ):
438
- row = i * self .factor_order
439
- self .ssm ["selection" , row , i ] = 1.0
465
+ factor_row = 0
466
+ for i , order in enumerate (self .factor_order ):
467
+ self .ssm ["selection" , factor_row , i ] = 1.0
468
+ factor_row += max (order , 1 )
440
469
441
470
for i in range (self .k_endog ):
442
- row = self .k_factors * self . factor_order + i * self .error_order
471
+ row = sum ( self .factor_order ) + i * self .error_order
443
472
col = self .k_factors + i
444
473
self .ssm ["selection" , row , col ] = 1.0
445
474
@@ -473,6 +502,3 @@ def build_ar_block_matrix(ar_coeffs):
473
502
"sigma_obs" , shape = (self .k_endog ,), dtype = floatX
474
503
)
475
504
self .ssm ["obs_cov" , :, :] = pt .diag (sigma_obs )
476
- else :
477
- # If measurement error is not used, set obs_cov to zero
478
- self .ssm ["obs_cov" , :, :] = 0.0
0 commit comments