Skip to content

Commit 0e78a6e

Browse files
More robust needs_exog_data
1 parent b8d2e4e commit 0e78a6e

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def __init__(
172172
if len(endog_names) != k_endog:
173173
raise ValueError("Length of provided endog_names does not match provided k_endog")
174174

175+
needs_exog_data = False
176+
175177
if k_exog is not None and not isinstance(k_exog, int | dict):
176178
raise ValueError("If not None, k_endog must be either an int or a dict")
177179
if exog_state_names is not None and not isinstance(exog_state_names, list | dict):
@@ -199,6 +201,7 @@ def __init__(
199201
"If both k_endog and exog_state_names are provided, lengths of exog_state_names "
200202
"lists must match corresponding values in k_exog"
201203
)
204+
needs_exog_data = True
202205

203206
if k_exog is not None and exog_state_names is None:
204207
if isinstance(k_exog, int):
@@ -207,12 +210,14 @@ def __init__(
207210
exog_state_names = {
208211
name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items()
209212
}
213+
needs_exog_data = True
210214

211215
if k_exog is None and exog_state_names is not None:
212216
if isinstance(exog_state_names, list):
213217
k_exog = len(exog_state_names)
214218
elif isinstance(exog_state_names, dict):
215219
k_exog = {name: len(names) for name, names in exog_state_names.items()}
220+
needs_exog_data = True
216221

217222
# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
218223
# then we can drop back to the list case.
@@ -245,7 +250,7 @@ def __init__(
245250
mode=mode,
246251
)
247252

248-
self._needs_exog_data = exog_state_names is not None and len(exog_state_names) > 0
253+
self._needs_exog_data = needs_exog_data
249254

250255
# Save counts of the number of parameters in each category
251256
self.param_counts = {

tests/statespace/models/test_VARMAX.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ def test_forecast_with_exog(self, rng):
580580
},
581581
)
582582

583+
assert mod._needs_exog_data
584+
583585
with freeze_dims_and_data(m):
584586
prior = pm.sample_prior_predictive(
585587
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}

0 commit comments

Comments
 (0)