Skip to content

Commit 4b6f804

Browse files
committed
Harmonizing names for EXOG dimension between DFM and VARMAX
1 parent a329450 commit 4b6f804

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
ALL_STATE_AUX_DIM,
1515
ALL_STATE_DIM,
1616
AR_PARAM_DIM,
17-
EXOGENOUS_DIM,
17+
EXOG_STATE_DIM,
1818
MA_PARAM_DIM,
1919
OBS_STATE_AUX_DIM,
2020
OBS_STATE_DIM,
@@ -342,15 +342,15 @@ def data_info(self) -> dict[str, dict[str, Any]]:
342342
if isinstance(self.exog_state_names, list):
343343
info = {
344344
"exogenous_data": {
345-
"dims": (TIME_DIM, EXOGENOUS_DIM),
345+
"dims": (TIME_DIM, EXOG_STATE_DIM),
346346
"shape": (None, self.k_exog),
347347
}
348348
}
349349

350350
elif isinstance(self.exog_state_names, dict):
351351
info = {
352352
f"{endog_state}_exogenous_data": {
353-
"dims": (TIME_DIM, f"{EXOGENOUS_DIM}_{endog_state}"),
353+
"dims": (TIME_DIM, f"{EXOG_STATE_DIM}_{endog_state}"),
354354
"shape": (None, len(exog_names)),
355355
}
356356
for endog_state, exog_names in self.exog_state_names.items()
@@ -399,10 +399,10 @@ def coords(self) -> dict[str, Sequence]:
399399
coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
400400

401401
if isinstance(self.exog_state_names, list):
402-
coords[EXOGENOUS_DIM] = self.exog_state_names
402+
coords[EXOG_STATE_DIM] = self.exog_state_names
403403
elif isinstance(self.exog_state_names, dict):
404404
for name, exog_names in self.exog_state_names.items():
405-
coords[f"{EXOGENOUS_DIM}_{name}"] = exog_names
405+
coords[f"{EXOG_STATE_DIM}_{name}"] = exog_names
406406

407407
return coords
408408

@@ -428,12 +428,12 @@ def param_dims(self):
428428
del coord_map["x0"]
429429

430430
if isinstance(self.exog_state_names, list):
431-
coord_map["beta_exog"] = (OBS_STATE_DIM, EXOGENOUS_DIM)
431+
coord_map["beta_exog"] = (OBS_STATE_DIM, EXOG_STATE_DIM)
432432
elif isinstance(self.exog_state_names, dict):
433433
# If each state has its own exogenous variables, each parameter needs it own dim, since we expect the
434434
# dim labels to all be different (otherwise we'd be in the list case).
435435
for name in self.exog_state_names.keys():
436-
coord_map[f"beta_{name}"] = (f"{EXOGENOUS_DIM}_{name}",)
436+
coord_map[f"beta_{name}"] = (f"{EXOG_STATE_DIM}_{name}",)
437437

438438
return coord_map
439439

0 commit comments

Comments
 (0)