Skip to content

Commit eaba40e

Browse files
committed
Added support for different autoregressive order for factor
Also addressed review comments on sigma_obs and measurement error
1 parent bc0b5c8 commit eaba40e

File tree

1 file changed

+57
-31
lines changed
  • pymc_extras/statespace/models

1 file changed

+57
-31
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ class BayesianDynamicFactor(PyMCStateSpace):
2929
k_factors : int
3030
Number of latent factors.
3131
32-
factor_order : int
33-
Order of the VAR process for the latent factors.
32+
factor_order : int or Sequence[int]
33+
Order of the VAR process for the latent factors. If an integer is provided, the same order is used for all factors.
34+
If a sequence of integers is provided, it specifies the order for each factor individually.
3435
3536
k_endog : int, optional
3637
Number of observed time series. If not provided, the number of observed series will be inferred from `endog_names`.
@@ -156,7 +157,7 @@ class BayesianDynamicFactor(PyMCStateSpace):
156157
def __init__(
157158
self,
158159
k_factors: int,
159-
factor_order: int,
160+
factor_order: int | Sequence[int],
160161
k_endog: int | None = None,
161162
endog_names: Sequence[str] | None = None,
162163
exog: np.ndarray | None = None,
@@ -182,6 +183,17 @@ def __init__(
182183
if exog is not None:
183184
raise NotImplementedError("Exogenous variables (exog) are not yet implemented.")
184185

186+
# Normalize factor_order to a list of length k_factors
187+
if isinstance(factor_order, int):
188+
factor_order = [factor_order] * k_factors
189+
elif isinstance(factor_order, Sequence):
190+
if len(factor_order) != k_factors:
191+
raise ValueError(
192+
f"factor_order must have length {k_factors} when given as a sequence."
193+
)
194+
else:
195+
raise TypeError("factor_order must be either an int or a sequence of ints.")
196+
185197
self.endog_names = endog_names
186198
self.k_endog = k_endog
187199
self.k_factors = k_factors
@@ -195,8 +207,12 @@ def __init__(
195207
# Determine the dimension for the latent factor states.
196208
# For static factors, one might use k_factors.
197209
# For dynamic factors with lags, the state might include current factors and past lags.
198-
# TODO: what if we want different factor orders for different factors? (follow suggestions in GitHub)
199-
k_factor_states = k_factors * factor_order
210+
# If factor_order is 0, we treat the factor as static (no dynamics),
211+
# but it is still included in the state vector with one state per factor.
212+
# Factor_ar paramter will not exist in this case.
213+
k_factor_states = sum(max(order, 1) for order in self.factor_order)
214+
215+
self._max_order = max(self.factor_order)
200216

201217
# Determine the dimension for the error component.
202218
# If error_order > 0 then we add additional states for error dynamics, otherwise white noise error.
@@ -234,7 +250,7 @@ def param_names(self):
234250
]
235251

236252
# Handle cases where parameters should be excluded based on model settings
237-
if self.factor_order == 0:
253+
if all(order == 0 for order in self.factor_order):
238254
names.remove("factor_ar")
239255
if self.error_order == 0:
240256
names.remove("error_ar")
@@ -262,7 +278,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
262278
"constraints": None,
263279
},
264280
"factor_ar": {
265-
"shape": (self.k_factors, self.factor_order),
281+
"shape": (self.k_factors, self._max_order),
266282
"constraints": None,
267283
},
268284
"factor_sigma": {
@@ -283,7 +299,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
283299
},
284300
"sigma_obs": {
285301
"shape": (self.k_endog,),
286-
"constraints": "Positive Semi-definite",
302+
"constraints": "Positive",
287303
},
288304
}
289305

@@ -301,9 +317,9 @@ def state_names(self) -> list[str]:
301317
names = []
302318
# TODO adjust notation by looking at the VARMAX implementation
303319
# Factor states
304-
for i in range(self.k_factors):
305-
for lag in range(self.factor_order):
306-
names.append(f"L{lag}.factor_{i+1}")
320+
for i, order in enumerate(self.factor_order):
321+
for j in range(max(order, 1)):
322+
names.append(f"factor_{i}_{j}")
307323

308324
# Idiosyncratic error states
309325
if self.error_order > 0:
@@ -327,8 +343,8 @@ def coords(self) -> dict[str, Sequence]:
327343
coords[FACTOR_DIM] = [f"factor_{i+1}" for i in range(self.k_factors)]
328344

329345
# AR parameter dimensions - add if needed
330-
if self.factor_order > 0:
331-
coords[AR_PARAM_DIM] = list(range(1, self.factor_order + 1))
346+
if self._max_order > 0:
347+
coords[AR_PARAM_DIM] = list(range(1, self._max_order + 1))
332348

333349
# If error_order > 0
334350
if self.error_order > 0:
@@ -359,19 +375,21 @@ def param_dims(self):
359375
"factor_loadings": (OBS_STATE_DIM, FACTOR_DIM),
360376
"factor_sigma": (FACTOR_DIM,),
361377
}
362-
363-
if self.factor_order > 0:
378+
if self._max_order > 0:
364379
coord_map["factor_ar"] = (FACTOR_DIM, AR_PARAM_DIM)
365380

366381
if self.error_order > 0:
367382
coord_map["error_ar"] = (OBS_STATE_DIM, ERROR_AR_PARAM_DIM)
368383

369384
if self.error_cov_type in ["scalar"]:
370385
coord_map["error_sigma"] = ()
386+
371387
elif self.error_cov_type in ["diagonal"]:
372388
coord_map["error_sigma"] = (OBS_STATE_DIM,)
389+
373390
if self.error_cov_type == "unstructured":
374391
coord_map["error_sigma"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
392+
375393
if self.measurement_error:
376394
coord_map["sigma_obs"] = (OBS_STATE_DIM,)
377395
return coord_map
@@ -396,12 +414,13 @@ def make_symbolic_graph(self):
396414

397415
self.ssm["design", :, :] = 0.0
398416

399-
for j in range(self.k_factors):
400-
col_idx = j * self.factor_order
401-
self.ssm["design", :, col_idx] = factor_loadings[:, j]
417+
factor_col = 0
418+
for i, order in enumerate(self.factor_order):
419+
self.ssm["design", :, factor_col] = factor_loadings[:, i]
420+
factor_col += max(order, 1)
402421

403422
for i in range(self.k_endog):
404-
col_idx = self.k_factors * self.factor_order + i * self.error_order
423+
col_idx = sum(max(order, 1) for order in self.factor_order) + i * self.error_order
405424
self.ssm["design", i, col_idx] = 1.0
406425

407426
# Transition matrix
@@ -415,11 +434,20 @@ def build_ar_block_matrix(ar_coeffs):
415434

416435
transition_blocks = []
417436

418-
factor_ar = self.make_and_register_variable(
419-
"factor_ar", shape=(self.k_factors, self.factor_order), dtype=floatX
420-
)
421-
for j in range(self.k_factors):
422-
transition_blocks.append(build_ar_block_matrix(factor_ar[j]))
437+
if self._max_order > 0:
438+
factor_ar = self.make_and_register_variable(
439+
"factor_ar", shape=(self.k_factors, self._max_order), dtype=floatX
440+
)
441+
for j in range(self.k_factors):
442+
order = self.factor_order[j]
443+
if order == 0:
444+
# For order=0, just add a 1x1 zero matrix (static factor)
445+
transition_blocks.append(pt.zeros((1, 1), dtype=floatX))
446+
else:
447+
transition_blocks.append(build_ar_block_matrix(factor_ar[j][:order]))
448+
else:
449+
# If no factor dynamics, just add a zero matrix
450+
transition_blocks.append(pt.zeros((self.k_factors, self.k_factors), dtype=floatX))
423451

424452
if self.error_order > 0:
425453
error_ar = self.make_and_register_variable(
@@ -434,12 +462,13 @@ def build_ar_block_matrix(ar_coeffs):
434462
# Selection matrix
435463
self.ssm["selection", :, :] = 0.0
436464

437-
for i in range(self.k_factors):
438-
row = i * self.factor_order
439-
self.ssm["selection", row, i] = 1.0
465+
factor_row = 0
466+
for i, order in enumerate(self.factor_order):
467+
self.ssm["selection", factor_row, i] = 1.0
468+
factor_row += max(order, 1)
440469

441470
for i in range(self.k_endog):
442-
row = self.k_factors * self.factor_order + i * self.error_order
471+
row = sum(self.factor_order) + i * self.error_order
443472
col = self.k_factors + i
444473
self.ssm["selection", row, col] = 1.0
445474

@@ -473,6 +502,3 @@ def build_ar_block_matrix(ar_coeffs):
473502
"sigma_obs", shape=(self.k_endog,), dtype=floatX
474503
)
475504
self.ssm["obs_cov", :, :] = pt.diag(sigma_obs)
476-
else:
477-
# If measurement error is not used, set obs_cov to zero
478-
self.ssm["obs_cov", :, :] = 0.0

0 commit comments

Comments
 (0)