Skip to content

Commit fbd5b1f

Browse files
committed
Adding a first implementation of exogenous variable support based on pymc_extras/statespace/models/structural/components/regression.py
1 parent 05c11d4 commit fbd5b1f

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
ALL_STATE_DIM,
1313
AR_PARAM_DIM,
1414
ERROR_AR_PARAM_DIM,
15+
EXOG_STATE_DIM,
1516
FACTOR_DIM,
1617
OBS_STATE_AUX_DIM,
1718
OBS_STATE_DIM,
19+
TIME_DIM,
1820
)
1921

2022
floatX = pytensor.config.floatX
@@ -283,8 +285,8 @@ def __init__(
283285
if endog_names is None:
284286
endog_names = [f"endog_{i}" for i in range(k_endog)]
285287

286-
if k_exog is not None or exog_names is not None:
287-
raise NotImplementedError("Exogenous variables (exog) are not yet implemented.")
288+
# if k_exog is not None or exog_names is not None:
289+
# raise NotImplementedError("Exogenous variables (exog) are not yet implemented.")
288290

289291
self.endog_names = endog_names
290292
self.k_endog = k_endog
@@ -293,7 +295,23 @@ def __init__(
293295
self.error_order = error_order
294296
self.error_var = error_var
295297
self.error_cov_type = error_cov_type
296-
# TODO add exogenous variables support
298+
299+
if k_exog is None and exog_names is None:
300+
self._exog = False
301+
self.k_exog = 0
302+
else:
303+
self._exog = True
304+
if k_exog is None:
305+
k_exog = len(exog_names) if exog_names is not None else 0
306+
elif exog_names is None:
307+
exog_names = [f"exog_{i}" for i in range(k_exog)] if k_exog > 0 else None
308+
309+
self.exog_names = exog_names
310+
self.k_exog = k_exog
311+
312+
# TODO add exogenous variables support (statsmodels handles exog without touching the state vector, working only on the measurement equation)
313+
# This is a first version of exog support, with shared_states=False, based on pymc_extras/statespace/models/structural/components/regression.py
314+
# Currently the beta coefficients are time-invariant, so innovations on beta are not supported
297315

298316
# Determine the dimension for the latent factor states.
299317
# For static factors, one uses k_factors.
@@ -308,11 +326,12 @@ def __init__(
308326
k_error_states = k_endog * error_order if error_order > 0 else 0
309327

310328
# Total state dimension
311-
k_states = k_factor_states + k_error_states
329+
k_states = k_factor_states + k_error_states + (k_exog * k_endog if self._exog else 0)
312330

313331
# Number of independent shocks.
314332
# Typically, the latent factors introduce k_factors shocks.
315333
# If error_order > 0 and errors are modeled jointly or separately, add appropriate count.
334+
# TODO currently the implementation does not support innovations on the beta coefficients
316335
k_posdef = k_factors + (k_endog if error_order > 0 else 0)
317336

318337
# Initialize the PyMCStateSpace base class.
@@ -346,6 +365,8 @@ def param_names(self):
346365
names.append("error_cov")
347366
if not self.measurement_error:
348367
names.remove("sigma_obs")
368+
if self._exog:
369+
names.append("beta")
349370

350371
return names
351372

@@ -387,6 +408,10 @@ def param_info(self) -> dict[str, dict[str, Any]]:
387408
"shape": (self.k_endog,),
388409
"constraints": "Positive",
389410
},
411+
"beta": {
412+
"shape": (self.k_exog * self.k_endog if self.k_exog is not None else 0,),
413+
"constraints": None,
414+
},
390415
}
391416

392417
for name in self.param_names:
@@ -398,7 +423,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
398423
def state_names(self) -> list[str]:
399424
"""
400425
Returns the names of the hidden states: first factor states (with lags),
401-
then idiosyncratic error states (with lags).
426+
idiosyncratic error states (with lags), then exogenous states.
402427
"""
403428
names = []
404429
# Factor states
@@ -412,6 +437,12 @@ def state_names(self) -> list[str]:
412437
for lag in range(self.error_order):
413438
names.append(f"L{lag}.error_{i}")
414439

440+
if self._exog:
441+
# Exogenous states
442+
for i in range(self.k_exog):
443+
for j in range(self.k_endog):
444+
names.append(f"exog_{i}.endog_{j}")
445+
415446
return names
416447

417448
@property
@@ -438,6 +469,10 @@ def coords(self) -> dict[str, Sequence]:
438469
else:
439470
coords[ERROR_AR_PARAM_DIM] = list(range(1, self.error_order + 1))
440471

472+
if self._exog:
473+
# Exogenous states
474+
coords[EXOG_STATE_DIM] = list(range(1, (self.k_exog * self.k_endog) + 1))
475+
441476
return coords
442477

443478
@property
@@ -479,8 +514,30 @@ def param_dims(self):
479514

480515
if self.measurement_error:
481516
coord_map["sigma_obs"] = (OBS_STATE_DIM,)
517+
518+
if self._exog:
519+
coord_map["beta"] = (EXOG_STATE_DIM,)
520+
# coord_map["exog_data"]
521+
482522
return coord_map
483523

524+
@property
525+
def data_info(self):
526+
if self._exog:
527+
return {
528+
"exog_data": {
529+
"shape": (None, self.k_exog),
530+
"dims": (TIME_DIM, EXOG_STATE_DIM),
531+
},
532+
}
533+
return {}
534+
535+
@property
536+
def data_names(self):
537+
if self._exog:
538+
return ["exog_data"]
539+
return []
540+
484541
def make_symbolic_graph(self):
485542
# Initial states
486543
x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)
@@ -498,13 +555,41 @@ def make_symbolic_graph(self):
498555
"factor_loadings", shape=(self.k_endog, self.k_factors), dtype=floatX
499556
)
500557

501-
for i in range(self.k_factors):
502-
self.ssm["design", :, i] = factor_loadings[:, i]
558+
# Start with factor loadings
559+
matrix_parts = [factor_loadings]
560+
561+
if self.factor_order > 1:
562+
matrix_parts.append(
563+
pt.zeros((self.k_endog, self.k_factors * (self.factor_order - 1)), dtype=floatX)
564+
)
503565

504566
if self.error_order > 0:
505-
for i in range(self.k_endog):
506-
col_idx = max(self.factor_order, 1) * self.k_factors + i
507-
self.ssm["design", i, col_idx] = 1.0
567+
# Create identity matrix for error terms
568+
error_matrix = pt.eye(self.k_endog, dtype=floatX)
569+
matrix_parts.append(error_matrix)
570+
matrix_parts.append(
571+
pt.zeros((self.k_endog, self.k_endog * (self.error_order - 1)), dtype=floatX)
572+
)
573+
574+
# Concatenate all parts
575+
design_matrix = pt.concatenate(matrix_parts, axis=1)
576+
577+
if self._exog:
578+
exog_data = self.make_and_register_data("exog_data", shape=(None, self.k_exog))
579+
Z_exog = pt.linalg.block_diag(
580+
*[pt.expand_dims(exog_data, 1) for _ in range(self.k_endog)]
581+
) # (time, k_endog, k_exog)
582+
Z_exog = pt.specify_shape(Z_exog, (None, self.k_endog, self.k_exog * self.k_endog))
583+
# Repeat design_matrix over time dimension
584+
n_timepoints = Z_exog.shape[0]
585+
design_matrix_time = pt.tile(
586+
design_matrix, (n_timepoints, 1, 1)
587+
) # (time, k_endog, states_before_exog)
588+
589+
# Concatenate along states dimension
590+
design_matrix = pt.concatenate([design_matrix_time, Z_exog], axis=2)
591+
592+
self.ssm["design"] = design_matrix
508593

509594
# Transition matrix
510595
# auxiliary function to build transition matrix block
@@ -584,6 +669,8 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
584669
transition_blocks.append(
585670
build_independent_var_block_matrix(error_ar, self.k_endog, self.error_order)
586671
)
672+
if self._exog:
673+
transition_blocks.append(pt.eye(self.k_exog * self.k_endog, dtype=floatX))
587674

588675
# Final block diagonal transition matrix
589676
self.ssm["transition", :, :] = pt.linalg.block_diag(*transition_blocks)
@@ -598,6 +685,8 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
598685
col = self.k_factors + i
599686
self.ssm["selection", row, col] = 1.0
600687

688+
# No changes to the selection matrix since there are no innovations related to the beta parameters
689+
601690
factor_cov = pt.eye(self.k_factors, dtype=floatX)
602691

603692
# Handle error_sigma and error_cov depending on error_cov_type

pymc_extras/statespace/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ETS_SEASONAL_DIM = "seasonal_lag"
1515
FACTOR_DIM = "factor"
1616
ERROR_AR_PARAM_DIM = "error_lag_ar"
17+
EXOG_STATE_DIM = "exogenous_state"
1718

1819
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
1920
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]

0 commit comments

Comments
 (0)