Skip to content

Commit 604f776

Browse files
Allow exogenous variables in BayesianVARMAX
1 parent ff44be5 commit 604f776

File tree

2 files changed

+280
-2
lines changed

2 files changed

+280
-2
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
ALL_STATE_AUX_DIM,
1515
ALL_STATE_DIM,
1616
AR_PARAM_DIM,
17+
EXOGENOUS_DIM,
1718
MA_PARAM_DIM,
1819
OBS_STATE_AUX_DIM,
1920
OBS_STATE_DIM,
2021
SHOCK_AUX_DIM,
2122
SHOCK_DIM,
23+
TIME_DIM,
2224
)
2325

2426
floatX = pytensor.config.floatX
@@ -264,6 +266,14 @@ def param_names(self):
264266
names.remove("ar_params")
265267
if self.q == 0:
266268
names.remove("ma_params")
269+
270+
# Add exogenous regression coefficients rather than remove, since we might have to handle
271+
# several (if self.exog_state_names is a dict)
272+
if isinstance(self.exog_state_names, list):
273+
names.append("beta_exog")
274+
elif isinstance(self.exog_state_names, dict):
275+
names.extend([f"beta_{name}" for name in self.exog_state_names.keys()])
276+
267277
return names
268278

269279
@property
@@ -295,11 +305,49 @@ def param_info(self) -> dict[str, dict[str, Any]]:
295305
},
296306
}
297307

308+
if isinstance(self.exog_state_names, list):
309+
k_exog = len(self.exog_state_names)
310+
info["beta_exog"] = {
311+
"shape": (self.k_endog, k_exog),
312+
"constraints": "None",
313+
}
314+
315+
elif isinstance(self.exog_state_names, dict):
316+
for name, exog_names in self.exog_state_names.items():
317+
k_exog = len(exog_names)
318+
info[f"beta_{name}"] = {
319+
"shape": (k_exog,),
320+
"constraints": "None",
321+
}
322+
298323
for name in self.param_names:
299324
info[name]["dims"] = self.param_dims[name]
300325

301326
return {name: info[name] for name in self.param_names}
302327

328+
@property
def data_info(self) -> dict[str, dict[str, Any]]:
    """
    Describe the exogenous data containers associated with ``exog_state_names``.

    Returns
    -------
    info : dict or None
        Mapping from data name to a dictionary with keys ``"dims"`` and ``"shape"``.

        - If ``exog_state_names`` is a list, there is a single entry, ``"exogenous_data"``, with dims
          ``(TIME_DIM, EXOGENOUS_DIM)`` and shape ``(None, self.k_exog)``.
        - If ``exog_state_names`` is a dict, there is one entry per key, named ``"{key}_exogenous_data"``, whose
          second dim is ``"{EXOGENOUS_DIM}_{key}"`` and whose second shape entry is the number of names listed
          under that key.
        - Otherwise ``None`` is returned.
    """
    exog_spec = self.exog_state_names

    # One shared set of regressors: a single data matrix.
    if isinstance(exog_spec, list):
        shared = {
            "dims": (TIME_DIM, EXOGENOUS_DIM),
            "shape": (None, self.k_exog),
        }
        return {"exogenous_data": shared}

    # Separate regressors per key: one data matrix each, with its own dim label.
    if isinstance(exog_spec, dict):
        per_state = {}
        for state, regressors in exog_spec.items():
            per_state[f"{state}_exogenous_data"] = {
                "dims": (TIME_DIM, f"{EXOGENOUS_DIM}_{state}"),
                "shape": (None, len(regressors)),
            }
        return per_state

    # exog_state_names is neither a list nor a dict
    return None
350+
303351
@property
304352
def data_names(self) -> list[str]:
305353
if isinstance(self.exog_state_names, list):
@@ -312,10 +360,10 @@ def data_names(self) -> list[str]:
312360
def state_names(self):
313361
state_names = self.endog_names.copy()
314362
state_names += [
315-
f"L{i + 1}.{state}" for i in range(self.p - 1) for state in self.endog_names
363+
f"L{i + 1}_{state}" for i in range(self.p - 1) for state in self.endog_names
316364
]
317365
state_names += [
318-
f"L{i + 1}.{state}_innov" for i in range(self.q) for state in self.endog_names
366+
f"L{i + 1}_{state}_innov" for i in range(self.q) for state in self.endog_names
319367
]
320368

321369
return state_names
@@ -340,6 +388,12 @@ def coords(self) -> dict[str, Sequence]:
340388
if self.q > 0:
341389
coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
342390

391+
if isinstance(self.exog_state_names, list):
392+
coords[EXOGENOUS_DIM] = self.exog_state_names
393+
elif isinstance(self.exog_state_names, dict):
394+
for name, exog_names in self.exog_state_names.items():
395+
coords[f"{EXOGENOUS_DIM}_{name}"] = exog_names
396+
343397
return coords
344398

345399
@property
@@ -363,6 +417,14 @@ def param_dims(self):
363417
del coord_map["P0"]
364418
del coord_map["x0"]
365419

420+
if isinstance(self.exog_state_names, list):
421+
coord_map["beta_exog"] = (OBS_STATE_DIM, EXOGENOUS_DIM)
422+
elif isinstance(self.exog_state_names, dict):
423+
# If each state has its own exogenous variables, each parameter needs its own dim, since we expect the
424+
# dim labels to all be different (otherwise we'd be in the list case).
425+
for name in self.exog_state_names.keys():
426+
coord_map[f"beta_{name}"] = (f"{EXOGENOUS_DIM}_{name}",)
427+
366428
return coord_map
367429

368430
def add_default_priors(self):
@@ -450,6 +512,61 @@ def make_symbolic_graph(self) -> None:
450512
)
451513
self.ssm["state_cov", :, :] = state_cov
452514

515+
if self.exog_state_names is not None:
516+
if isinstance(self.exog_state_names, list):
517+
beta_exog = self.make_and_register_variable(
518+
"beta_exog", shape=(self.k_posdef, self.k_exog), dtype=floatX
519+
)
520+
exog_data = self.make_and_register_data(
521+
"exogenous_data", shape=(None, self.k_exog), dtype=floatX
522+
)
523+
524+
obs_intercept = exog_data @ beta_exog.T
525+
526+
elif isinstance(self.exog_state_names, dict):
527+
obs_components = []
528+
for i, name in enumerate(self.endog_names):
529+
if name in self.exog_state_names:
530+
k_exog = len(self.exog_state_names[name])
531+
beta_exog = self.make_and_register_variable(
532+
f"beta_{name}", shape=(k_exog,), dtype=floatX
533+
)
534+
exog_data = self.make_and_register_data(
535+
f"{name}_exogenous_data", shape=(None, k_exog), dtype=floatX
536+
)
537+
obs_components.append(pt.expand_dims(exog_data @ beta_exog, axis=-1))
538+
else:
539+
obs_components.append(pt.zeros((1, 1), dtype=floatX))
540+
541+
# TODO: Replace all of this with pt.concat_with_broadcast once PyMC works with pytensor >= 2.32
542+
543+
# If there were any zeros, they need to be broadcast against the non-zeros.
544+
# Core shape is the last dim, the time dim is always broadcast
545+
non_concat_shape = [1, None]
546+
547+
# Look for the first non-zero component to get the shape from
548+
for tensor_inp in obs_components:
549+
for i, (bcast, sh) in enumerate(
550+
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
551+
):
552+
if bcast or i == 1:
553+
continue
554+
non_concat_shape[i] = sh
555+
556+
assert non_concat_shape.count(None) == 1
557+
558+
bcast_tensor_inputs = []
559+
for tensor_inp in obs_components:
560+
non_concat_shape[1] = tensor_inp.shape[1]
561+
bcast_tensor_inputs.append(pt.broadcast_to(tensor_inp, non_concat_shape))
562+
563+
obs_intercept = pt.join(1, *bcast_tensor_inputs)
564+
565+
else:
566+
raise NotImplementedError()
567+
568+
self.ssm["obs_intercept"] = obs_intercept
569+
453570
if self.stationary_initialization:
454571
# Solve for matrix quadratic for P0
455572
T = self.ssm["transition"]

tests/statespace/models/test_VARMAX.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import statsmodels.api as sm
1010

1111
from numpy.testing import assert_allclose, assert_array_less
12+
from pymc.model.transform.optimization import freeze_dims_and_data
1213

1314
from pymc_extras.statespace import BayesianVARMAX
1415
from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
@@ -203,6 +204,10 @@ def test_create_varmax_with_exogenous(data):
203204
assert mod.k_exog == 2
204205
assert mod.exog_state_names == ["exogenous_0", "exogenous_1"]
205206
assert mod.data_names == ["exogenous_data"]
207+
assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous")
208+
assert mod.coords["exogenous"] == ["exogenous_0", "exogenous_1"]
209+
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
210+
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
206211

207212
# Case 2: exog_state_names as list, k_exog is None
208213
mod = BayesianVARMAX(
@@ -216,6 +221,10 @@ def test_create_varmax_with_exogenous(data):
216221
assert mod.k_exog == 2
217222
assert mod.exog_state_names == ["foo", "bar"]
218223
assert mod.data_names == ["exogenous_data"]
224+
assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous")
225+
assert mod.coords["exogenous"] == ["foo", "bar"]
226+
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
227+
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
219228

220229
# Case 3: k_exog as int, exog_state_names as list (matching)
221230
mod = BayesianVARMAX(
@@ -230,6 +239,10 @@ def test_create_varmax_with_exogenous(data):
230239
assert mod.k_exog == 2
231240
assert mod.exog_state_names == ["a", "b"]
232241
assert mod.data_names == ["exogenous_data"]
242+
assert mod.param_dims["beta_exog"] == ("observed_state", "exogenous")
243+
assert mod.coords["exogenous"] == ["a", "b"]
244+
assert mod.param_info["beta_exog"]["shape"] == (mod.k_endog, 2)
245+
assert mod.param_info["beta_exog"]["dims"] == ("observed_state", "exogenous")
233246

234247
# Case 4: k_exog as dict, exog_state_names is None
235248
k_exog = {"observed_0": 2, "observed_1": 1, "observed_2": 0}
@@ -252,6 +265,25 @@ def test_create_varmax_with_exogenous(data):
252265
"observed_1_exogenous_data",
253266
"observed_2_exogenous_data",
254267
]
268+
assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",)
269+
assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",)
270+
assert (
271+
"beta_observed_2" not in mod.param_dims
272+
or mod.param_info.get("beta_observed_2") is None
273+
or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0
274+
)
275+
276+
assert mod.coords["exogenous_observed_0"] == [
277+
"observed_0_exogenous_0",
278+
"observed_0_exogenous_1",
279+
]
280+
assert mod.coords["exogenous_observed_1"] == ["observed_1_exogenous_0"]
281+
assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == []
282+
283+
assert mod.param_info["beta_observed_0"]["shape"] == (2,)
284+
assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",)
285+
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
286+
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
255287

256288
# Case 5: exog_state_names as dict, k_exog is None
257289
exog_state_names = {"observed_0": ["a", "b"], "observed_1": ["c"], "observed_2": []}
@@ -270,6 +302,22 @@ def test_create_varmax_with_exogenous(data):
270302
"observed_1_exogenous_data",
271303
"observed_2_exogenous_data",
272304
]
305+
assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",)
306+
assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",)
307+
assert (
308+
"beta_observed_2" not in mod.param_dims
309+
or mod.param_info.get("beta_observed_2") is None
310+
or mod.param_info.get("beta_observed_2", {}).get("shape", (0,))[0] == 0
311+
)
312+
313+
assert mod.coords["exogenous_observed_0"] == ["a", "b"]
314+
assert mod.coords["exogenous_observed_1"] == ["c"]
315+
assert "exogenous_observed_2" in mod.coords and mod.coords["exogenous_observed_2"] == []
316+
317+
assert mod.param_info["beta_observed_0"]["shape"] == (2,)
318+
assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",)
319+
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
320+
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
273321

274322
# Case 6: k_exog as dict, exog_state_names as dict (matching)
275323
k_exog = {"observed_0": 2, "observed_1": 1}
@@ -286,6 +334,14 @@ def test_create_varmax_with_exogenous(data):
286334
assert mod.k_exog == k_exog
287335
assert mod.exog_state_names == exog_state_names
288336
assert mod.data_names == ["observed_0_exogenous_data", "observed_1_exogenous_data"]
337+
assert mod.param_dims["beta_observed_0"] == ("exogenous_observed_0",)
338+
assert mod.param_dims["beta_observed_1"] == ("exogenous_observed_1",)
339+
assert mod.coords["exogenous_observed_0"] == ["a", "b"]
340+
assert mod.coords["exogenous_observed_1"] == ["c"]
341+
assert mod.param_info["beta_observed_0"]["shape"] == (2,)
342+
assert mod.param_info["beta_observed_0"]["dims"] == ("exogenous_observed_0",)
343+
assert mod.param_info["beta_observed_1"]["shape"] == (1,)
344+
assert mod.param_info["beta_observed_1"]["dims"] == ("exogenous_observed_1",)
289345

290346
# Error: k_exog as int, exog_state_names as list (length mismatch)
291347
with pytest.raises(
@@ -348,3 +404,108 @@ def test_create_varmax_with_exogenous(data):
348404
measurement_error=False,
349405
stationary_initialization=False,
350406
)
407+
408+
409+
@pytest.mark.parametrize(
    "k_exog, exog_state_names",
    [
        (2, None),
        (None, ["foo", "bar"]),
        (None, {"y1": ["a", "b"], "y2": ["c"]}),
    ],
    ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(rng, k_exog, exog_state_names):
    """
    Build a VAR(1) model with exogenous regressors, sample from the prior, and check the regression term.

    For each parametrization, the difference between ``filtered_prior_observed`` and ``filtered_prior`` from the
    conditional prior is compared (``atol=1e-2``) with the product of the exogenous data and the sampled beta
    coefficients. In the dict case ``y3`` has no regressors, so its expected contribution is zero.
    """
    endog_names = ["y1", "y2", "y3"]
    n_obs = 50
    time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
    per_state_exog = isinstance(exog_state_names, dict)

    # Endogenous data is drawn first so the sequence of rng calls matches across all cases
    endog_draws = rng.normal(size=(n_obs, len(endog_names)))
    df = pd.DataFrame(endog_draws, columns=endog_names, index=time_idx).astype(floatX)

    exog_data = {}
    if per_state_exog:
        for state, regressors in exog_state_names.items():
            draws = rng.normal(size=(n_obs, len(regressors))).astype(floatX)
            exog_data[f"{state}_exogenous_data"] = pd.DataFrame(
                draws, columns=regressors, index=time_idx
            )
    else:
        columns = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
        width = k_exog or len(exog_state_names)
        draws = rng.normal(size=(n_obs, width)).astype(floatX)
        exog_data["exogenous_data"] = pd.DataFrame(draws, columns=columns, index=time_idx)

    mod = BayesianVARMAX(
        endog_names=endog_names,
        order=(1, 0),
        k_exog=k_exog,
        exog_state_names=exog_state_names,
        verbose=True,
        measurement_error=False,
        stationary_initialization=False,
        mode="JAX",
    )

    with pm.Model(coords=mod.coords) as m:
        # Register every exogenous data matrix under the dims the model reports for it
        for data_name, frame in exog_data.items():
            pm.Data(data_name, frame, dims=mod.data_info[data_name]["dims"])

        pm.Deterministic("x0", pt.zeros(mod.k_states), dims=mod.param_dims["x0"])
        P0_diag = pm.Exponential("P0_diag", 1.0, dims=mod.param_dims["P0"][0])
        pm.Deterministic("P0", pt.diag(P0_diag), dims=mod.param_dims["P0"])

        pm.Normal("ar_params", mu=0, sigma=1, dims=mod.param_dims["ar_params"])
        state_cov_diag = pm.Exponential("state_cov_diag", 1.0, dims=mod.param_dims["state_cov"][0])
        pm.Deterministic("state_cov", pt.diag(state_cov_diag), dims=mod.param_dims["state_cov"])

        # Exogenous priors
        if isinstance(mod.exog_state_names, list):
            pm.Normal("beta_exog", mu=0, sigma=1, dims=mod.param_dims["beta_exog"])
        elif isinstance(mod.exog_state_names, dict):
            for state, regressors in mod.exog_state_names.items():
                # Only states with a non-empty regressor list get a prior
                if regressors:
                    pm.Normal(
                        f"beta_{state}", mu=0, sigma=1, dims=mod.param_dims[f"beta_{state}"]
                    )

        mod.build_statespace_graph(data=df)

    with freeze_dims_and_data(m):
        prior = pm.sample_prior_predictive(
            draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
        )

    prior_cond = mod.sample_conditional_prior(prior, mvn_method="eigh")
    filtered_observed = prior_cond.filtered_prior_observed.values
    filtered_state = prior_cond.filtered_prior.values
    beta_dot_data = filtered_observed - filtered_state

    shared_regressors = isinstance(exog_state_names, list) or k_exog is not None

    if shared_regressors:
        beta = prior.prior.beta_exog
        assert beta.shape == (1, 10, 3, 2)

        expected = np.einsum("tx,...sx->...ts", exog_data["exogenous_data"].values, beta)
        np.testing.assert_allclose(beta_dot_data, expected, atol=1e-2)

    elif per_state_exog:
        betas = {"y1": prior.prior.beta_y1, "y2": prior.prior.beta_y2}
        assert betas["y1"].shape == (1, 10, 2)
        assert betas["y2"].shape == (1, 10, 1)

        obs_intercept = []
        for state, beta in betas.items():
            state_data = exog_data[f"{state}_exogenous_data"].values
            obs_intercept.append(np.einsum("tx,...x->...t", state_data, beta))

        # y3 has no exogenous variables
        obs_intercept.append(np.zeros_like(obs_intercept[0]))

        np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)

0 commit comments

Comments
 (0)