Skip to content

Commit 31cba0b

Browse files
Eagerly simplify model where possible
1 parent 604f776 commit 31cba0b

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,16 @@ def __init__(
223223
elif isinstance(exog_state_names, dict):
224224
k_exog = {name: len(names) for name, names in exog_state_names.items()}
225225

226+
# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
227+
# then we can drop back to the list case.
228+
if (
229+
isinstance(exog_state_names, dict)
230+
and set(exog_state_names.keys()) == set(endog_names)
231+
and len({frozenset(val) for val in exog_state_names.values()}) == 1
232+
):
233+
exog_state_names = exog_state_names[endog_names[0]]
234+
k_exog = len(exog_state_names)
235+
226236
self.endog_names = list(endog_names)
227237
self.exog_state_names = exog_state_names
228238

tests/statespace/models/test_VARMAX.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,7 @@ def test_impulse_response(parameters, varma_mod, idata, rng):
191191
assert not np.any(np.isnan(irf.irf.values))
192192

193193

194-
def test_create_varmax_with_exogenous(data):
195-
# Case 1: k_exog as int, exog_state_names is None
194+
def test_create_varmax_with_exogenous_k_exog_int(data):
196195
mod = BayesianVARMAX(
197196
k_endog=data.shape[1],
198197
order=(1, 0),
@@ -209,7 +208,8 @@ def test_create_varmax_with_exogenous(data):
209208
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
210209
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
211210

212-
# Case 2: exog_state_names as list, k_exog is None
211+
212+
def test_create_varmax_with_exogenous_list_of_names(data):
213213
mod = BayesianVARMAX(
214214
k_endog=data.shape[1],
215215
order=(1, 0),
@@ -226,7 +226,8 @@ def test_create_varmax_with_exogenous(data):
226226
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
227227
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
228228

229-
# Case 3: k_exog as int, exog_state_names as list (matching)
229+
230+
def test_create_varmax_with_exogenous_both_defined_correctly(data):
230231
mod = BayesianVARMAX(
231232
k_endog=data.shape[1],
232233
order=(1, 0),
@@ -244,7 +245,8 @@ def test_create_varmax_with_exogenous(data):
244245
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
245246
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
246247

247-
# Case 4: k_exog as dict, exog_state_names is None
248+
249+
def test_create_varmax_with_exogenous_k_exog_dict(data):
248250
k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0}
249251
mod = BayesianVARMAX(
250252
endog_names=["observed_0", "observed_1", "observed_2"],
@@ -285,7 +287,8 @@ def test_create_varmax_with_exogenous(data):
285287
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
286288
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
287289

288-
# Case 5: exog_state_names as dict, k_exog is None
290+
291+
def test_create_varmax_with_exogenous_exog_names_dict(data):
289292
exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []}
290293
mod = BayesianVARMAX(
291294
endog_names=["observed_0", "observed_1", "observed_2"],
@@ -319,7 +322,8 @@ def test_create_varmax_with_exogenous(data):
319322
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
320323
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
321324

322-
# Case 6: k_exog as dict, exog_state_names as dict (matching)
325+
326+
def test_create_varmax_with_exogenous_both_dict_correct(data):
323327
k_exog = {"observed_0": 2, "observed_1": 1}
324328
exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"]}
325329
mod = BayesianVARMAX(
@@ -343,7 +347,33 @@ def test_create_varmax_with_exogenous(data):
343347
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
344348
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
345349

346-
# Error: k_exog as int, exog_state_names as list (length mismatch)
350+
351+
def test_create_varmax_with_exogenous_dict_converts_to_list(data):
352+
exog_state_names = {
353+
"observed_0": ["a", "b"],
354+
"observed_1": ["a", "b"],
355+
"observed_2": ["a", "b"],
356+
}
357+
mod = BayesianVARMAX(
358+
endog_names=["observed_0", "observed_1", "observed_2"],
359+
order=(1, 0),
360+
exog_state_names=exog_state_names,
361+
verbose=False,
362+
measurement_error=False,
363+
stationary_initialization=False,
364+
)
365+
366+
assert mod.k_exog == 2
367+
assert mod.exog_state_names == ["a", "b"]
368+
assert mod.data_names == ["exogenous_data"]
369+
assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous")
370+
assert mod.coords["exogenous"] == ["a", "b"]
371+
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
372+
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
373+
374+
375+
def test_create_varmax_with_exogenous_raises_if_args_disagree(data):
376+
# List case
347377
with pytest.raises(
348378
ValueError, match="Length of exog_state_names does not match provided k_exog"
349379
):
@@ -357,8 +387,10 @@ def test_create_varmax_with_exogenous(data):
357387
stationary_initialization=False,
358388
)
359389

360-
# Error: k_exog as int, exog_state_names as dict
361-
with pytest.raises(ValueError):
390+
# Dict case
391+
with pytest.raises(
392+
ValueError, match="If k_exog is an int, exog_state_names must be a list of the same length"
393+
):
362394
BayesianVARMAX(
363395
k_endog=2,
364396
order=(1, 0),
@@ -369,8 +401,10 @@ def test_create_varmax_with_exogenous(data):
369401
stationary_initialization=False,
370402
)
371403

372-
# Error: k_exog as dict, exog_state_names as list
373-
with pytest.raises(ValueError):
404+
# dict + list
405+
with pytest.raises(
406+
ValueError, match="If k_exog is a dict, exog_state_names must be a dict as well"
407+
):
374408
BayesianVARMAX(
375409
endog_names=["observed_0", "observed_1"],
376410
order=(1, 0),
@@ -381,7 +415,7 @@ def test_create_varmax_with_exogenous(data):
381415
stationary_initialization=False,
382416
)
383417

384-
# Error: k_exog as dict, exog_state_names as dict (keys mismatch)
418+
# Dict/dict, key mismatch
385419
with pytest.raises(ValueError, match="Keys of k_exog and exog_state_names dicts must match"):
386420
BayesianVARMAX(
387421
endog_names=["observed_0", "observed_1"],
@@ -393,7 +427,7 @@ def test_create_varmax_with_exogenous(data):
393427
stationary_initialization=False,
394428
)
395429

396-
# Error: k_exog as dict, exog_state_names as dict (length mismatch)
430+
# Dict/dict, length mismatch
397431
with pytest.raises(ValueError, match="lengths of exog_state_names lists must match"):
398432
BayesianVARMAX(
399433
endog_names=["observed_0", "observed_1"],

0 commit comments

Comments
 (0)