|
17 | 17 | ALL_STATE_AUX_DIM,
|
18 | 18 | ALL_STATE_DIM,
|
19 | 19 | AR_PARAM_DIM,
|
| 20 | + EXOGENOUS_DIM, |
20 | 21 | MA_PARAM_DIM,
|
21 | 22 | OBS_STATE_DIM,
|
22 | 23 | SARIMAX_STATE_STRUCTURES,
|
@@ -314,7 +315,7 @@ def param_names(self):
|
314 | 315 | def data_info(self) -> dict[str, dict[str, Any]]:
|
315 | 316 | info = {
|
316 | 317 | "exogenous_data": {
|
317 |
| - "dims": (TIME_DIM, "exogenous"), |
| 318 | + "dims": (TIME_DIM, EXOGENOUS_DIM), |
318 | 319 | "shape": (None, self.k_exog),
|
319 | 320 | }
|
320 | 321 | }
|
@@ -350,7 +351,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
|
350 | 351 | },
|
351 | 352 | "seasonal_ar_params": {"shape": (self.P,), "constraints": "None"},
|
352 | 353 | "seasonal_ma_params": {"shape": (self.Q,), "constraints": "None"},
|
353 |
| - "beta_exog": {"shape": (self.k_exog,), "constraints": "None", "dims": ("exogenous",)}, |
| 354 | + "beta_exog": {"shape": (self.k_exog,), "constraints": "None"}, |
354 | 355 | }
|
355 | 356 |
|
356 | 357 | for name in self.param_names:
|
@@ -402,7 +403,7 @@ def param_dims(self):
|
402 | 403 | "ma_params": (MA_PARAM_DIM,),
|
403 | 404 | "seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
|
404 | 405 | "seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
|
405 |
| - "beta_exog": ("exogenous",), |
| 406 | + "beta_exog": (EXOGENOUS_DIM,), |
406 | 407 | }
|
407 | 408 | if self.k_endog == 1:
|
408 | 409 | coord_map["sigma_state"] = None
|
@@ -437,7 +438,7 @@ def coords(self) -> dict[str, Sequence]:
|
437 | 438 | if self.Q > 0:
|
438 | 439 | coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})
|
439 | 440 | if self.k_exog > 0:
|
440 |
| - coords.update({"exogenous": self.exog_state_names}) |
| 441 | + coords.update({EXOGENOUS_DIM: self.exog_state_names}) |
441 | 442 | return coords
|
442 | 443 |
|
443 | 444 | def _stationary_initialization(self):
|
|
0 commit comments