1212 ALL_STATE_DIM ,
1313 AR_PARAM_DIM ,
1414 ERROR_AR_PARAM_DIM ,
15+ EXOG_STATE_DIM ,
1516 FACTOR_DIM ,
1617 OBS_STATE_AUX_DIM ,
1718 OBS_STATE_DIM ,
19+ TIME_DIM ,
1820)
1921
2022floatX = pytensor .config .floatX
@@ -283,8 +285,8 @@ def __init__(
283285 if endog_names is None :
284286 endog_names = [f"endog_{ i } " for i in range (k_endog )]
285287
286- if k_exog is not None or exog_names is not None :
287- raise NotImplementedError ("Exogenous variables (exog) are not yet implemented." )
288+ # if k_exog is not None or exog_names is not None:
289+ # raise NotImplementedError("Exogenous variables (exog) are not yet implemented.")
288290
289291 self .endog_names = endog_names
290292 self .k_endog = k_endog
@@ -293,7 +295,23 @@ def __init__(
293295 self .error_order = error_order
294296 self .error_var = error_var
295297 self .error_cov_type = error_cov_type
296- # TODO add exogenous variables support
298+
299+ if k_exog is None and exog_names is None :
300+ self ._exog = False
301+ self .k_exog = 0
302+ else :
303+ self ._exog = True
304+ if k_exog is None :
305+ k_exog = len (exog_names ) if exog_names is not None else 0
306+ elif exog_names is None :
307+ exog_names = [f"exog_{ i } " for i in range (k_exog )] if k_exog > 0 else None
308+
309+ self .exog_names = exog_names
310+ self .k_exog = k_exog
311+
312+ # TODO add exogenous variables support (statsmodel dealt with exog without touching state vector,but just working on the measurement equation)
313+ # I start implementing a version of exog support, with shared_states=False based on pymc_extras/statespace/models/structural/components/regression.py
314+ # currently the beta coefficients are time invariant, so the innovation on beta are not supported
297315
298316 # Determine the dimension for the latent factor states.
299317 # For static factors, one use k_factors.
@@ -308,11 +326,12 @@ def __init__(
308326 k_error_states = k_endog * error_order if error_order > 0 else 0
309327
310328 # Total state dimension
311- k_states = k_factor_states + k_error_states
329+ k_states = k_factor_states + k_error_states + ( k_exog * k_endog if self . _exog else 0 )
312330
313331 # Number of independent shocks.
314332 # Typically, the latent factors introduce k_factors shocks.
315333 # If error_order > 0 and errors are modeled jointly or separately, add appropriate count.
334+ # TODO currently the implementation does not support for innovation on betas coefficient
316335 k_posdef = k_factors + (k_endog if error_order > 0 else 0 )
317336
318337 # Initialize the PyMCStateSpace base class.
@@ -346,6 +365,8 @@ def param_names(self):
346365 names .append ("error_cov" )
347366 if not self .measurement_error :
348367 names .remove ("sigma_obs" )
368+ if self ._exog :
369+ names .append ("beta" )
349370
350371 return names
351372
@@ -387,6 +408,10 @@ def param_info(self) -> dict[str, dict[str, Any]]:
387408 "shape" : (self .k_endog ,),
388409 "constraints" : "Positive" ,
389410 },
411+ "beta" : {
412+ "shape" : (self .k_exog * self .k_endog if self .k_exog is not None else 0 ,),
413+ "constraints" : None ,
414+ },
390415 }
391416
392417 for name in self .param_names :
@@ -398,7 +423,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
398423 def state_names (self ) -> list [str ]:
399424 """
400425 Returns the names of the hidden states: first factor states (with lags),
401- then idiosyncratic error states (with lags).
426+ idiosyncratic error states (with lags), then exogenous states .
402427 """
403428 names = []
404429 # Factor states
@@ -412,6 +437,12 @@ def state_names(self) -> list[str]:
412437 for lag in range (self .error_order ):
413438 names .append (f"L{ lag } .error_{ i } " )
414439
440+ if self ._exog :
441+ # Exogenous states
442+ for i in range (self .k_exog ):
443+ for j in range (self .k_endog ):
444+ names .append (f"exog_{ i } .endog_{ j } " )
445+
415446 return names
416447
417448 @property
@@ -438,6 +469,10 @@ def coords(self) -> dict[str, Sequence]:
438469 else :
439470 coords [ERROR_AR_PARAM_DIM ] = list (range (1 , self .error_order + 1 ))
440471
472+ if self ._exog :
473+ # Exogenous states
474+ coords [EXOG_STATE_DIM ] = list (range (1 , (self .k_exog * self .k_endog ) + 1 ))
475+
441476 return coords
442477
443478 @property
@@ -479,8 +514,30 @@ def param_dims(self):
479514
480515 if self .measurement_error :
481516 coord_map ["sigma_obs" ] = (OBS_STATE_DIM ,)
517+
518+ if self ._exog :
519+ coord_map ["beta" ] = (EXOG_STATE_DIM ,)
520+ # coord_map["exog_data"]
521+
482522 return coord_map
483523
524+ @property
525+ def data_info (self ):
526+ if self ._exog :
527+ return {
528+ "exog_data" : {
529+ "shape" : (None , self .k_exog ),
530+ "dims" : (TIME_DIM , EXOG_STATE_DIM ),
531+ },
532+ }
533+ return {}
534+
535+ @property
536+ def data_names (self ):
537+ if self ._exog :
538+ return ["exog_data" ]
539+ return []
540+
484541 def make_symbolic_graph (self ):
485542 # Initial states
486543 x0 = self .make_and_register_variable ("x0" , shape = (self .k_states ,), dtype = floatX )
@@ -498,13 +555,41 @@ def make_symbolic_graph(self):
498555 "factor_loadings" , shape = (self .k_endog , self .k_factors ), dtype = floatX
499556 )
500557
501- for i in range (self .k_factors ):
502- self .ssm ["design" , :, i ] = factor_loadings [:, i ]
558+ # Start with factor loadings
559+ matrix_parts = [factor_loadings ]
560+
561+ if self .factor_order > 1 :
562+ matrix_parts .append (
563+ pt .zeros ((self .k_endog , self .k_factors * (self .factor_order - 1 )), dtype = floatX )
564+ )
503565
504566 if self .error_order > 0 :
505- for i in range (self .k_endog ):
506- col_idx = max (self .factor_order , 1 ) * self .k_factors + i
507- self .ssm ["design" , i , col_idx ] = 1.0
567+ # Create identity matrix for error terms
568+ error_matrix = pt .eye (self .k_endog , dtype = floatX )
569+ matrix_parts .append (error_matrix )
570+ matrix_parts .append (
571+ pt .zeros ((self .k_endog , self .k_endog * (self .error_order - 1 )), dtype = floatX )
572+ )
573+
574+ # Concatenate all parts
575+ design_matrix = pt .concatenate (matrix_parts , axis = 1 )
576+
577+ if self ._exog :
578+ exog_data = self .make_and_register_data ("exog_data" , shape = (None , self .k_exog ))
579+ Z_exog = pt .linalg .block_diag (
580+ * [pt .expand_dims (exog_data , 1 ) for _ in range (self .k_endog )]
581+ ) # (time, k_endog, k_exog)
582+ Z_exog = pt .specify_shape (Z_exog , (None , self .k_endog , self .k_exog * self .k_endog ))
583+ # Repeat design_matrix over time dimension
584+ n_timepoints = Z_exog .shape [0 ]
585+ design_matrix_time = pt .tile (
586+ design_matrix , (n_timepoints , 1 , 1 )
587+ ) # (time, k_endog, states_before_exog)
588+
589+ # Concatenate along states dimension
590+ design_matrix = pt .concatenate ([design_matrix_time , Z_exog ], axis = 2 )
591+
592+ self .ssm ["design" ] = design_matrix
508593
509594 # Transition matrix
510595 # auxiliary function to build transition matrix block
@@ -584,6 +669,8 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
584669 transition_blocks .append (
585670 build_independent_var_block_matrix (error_ar , self .k_endog , self .error_order )
586671 )
672+ if self ._exog :
673+ transition_blocks .append (pt .eye (self .k_exog * self .k_endog , dtype = floatX ))
587674
588675 # Final block diagonal transition matrix
589676 self .ssm ["transition" , :, :] = pt .linalg .block_diag (* transition_blocks )
@@ -598,6 +685,8 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
598685 col = self .k_factors + i
599686 self .ssm ["selection" , row , col ] = 1.0
600687
688+ # No changes in selection matrix since there are not innovations related to the betas parameters
689+
601690 factor_cov = pt .eye (self .k_factors , dtype = floatX )
602691
603692 # Handle error_sigma and error_cov depending on error_cov_type
0 commit comments