Skip to content

Commit 4d0b027

Browse files
Use constant for exogenous dim
1 parent 6a6304b commit 4d0b027

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ALL_STATE_AUX_DIM,
1818
ALL_STATE_DIM,
1919
AR_PARAM_DIM,
20+
EXOGENOUS_DIM,
2021
MA_PARAM_DIM,
2122
OBS_STATE_DIM,
2223
SARIMAX_STATE_STRUCTURES,
@@ -314,7 +315,7 @@ def param_names(self):
314315
def data_info(self) -> dict[str, dict[str, Any]]:
315316
info = {
316317
"exogenous_data": {
317-
"dims": (TIME_DIM, "exogenous"),
318+
"dims": (TIME_DIM, EXOGENOUS_DIM),
318319
"shape": (None, self.k_exog),
319320
}
320321
}
@@ -350,7 +351,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
350351
},
351352
"seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
352353
"seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
353-
"beta_exog": {"shape": (self.k_exog,), "constraints": "None", "dims": ("exogenous",)},
354+
"beta_exog": {"shape": (self.k_exog,), "constraints": "None"},
354355
}
355356

356357
for name in self.param_names:
@@ -402,7 +403,7 @@ def param_dims(self):
402403
"ma_params": (MA_PARAM_DIM,),
403404
"seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
404405
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
405-
"beta_exog": ("exogenous",),
406+
"beta_exog": (EXOGENOUS_DIM,),
406407
}
407408
if self.k_endog == 1:
408409
coord_map["sigma_state"] = None
@@ -437,7 +438,7 @@ def coords(self) -> dict[str, Sequence]:
437438
if self.Q > 0:
438439
coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})
439440
if self.k_exog > 0:
440-
coords.update({"exogenous": self.exog_state_names})
441+
coords.update({EXOGENOUS_DIM: self.exog_state_names})
441442
return coords
442443

443444
def _stationary_initialization(self):

pymc_extras/statespace/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
SEASONAL_AR_PARAM_DIM = "seasonal_lag_ar"
1313
SEASONAL_MA_PARAM_DIM = "seasonal_lag_ma"
1414
ETS_SEASONAL_DIM = "seasonal_lag"
15+
EXOGENOUS_DIM = "exogenous"
1516

1617
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
1718
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]

0 commit comments

Comments
 (0)