Skip to content

Commit c694646

Browse files
Regression component bugfix and tests
1 parent a102e3c commit c694646

File tree

3 files changed

+93
-18
lines changed

3 files changed

+93
-18
lines changed

pymc_extras/statespace/models/structural/components/regression.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class RegressionComponent(Component):
1010
def __init__(
1111
self,
1212
k_exog: int | None = None,
13-
name: str | None = "Exogenous",
13+
name: str | None = "regression",
1414
state_names: list[str] | None = None,
1515
observed_state_names: list[str] | None = None,
1616
innovations=False,
@@ -61,38 +61,38 @@ def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -
6161
def make_symbolic_graph(self) -> None:
6262
k_endog = self.k_endog
6363
k_states = self.k_states // k_endog
64-
self.k_posdef // k_endog
6564

66-
betas = self.make_and_register_variable(f"beta_{self.name}", shape=(k_endog, k_states))
65+
betas = self.make_and_register_variable(
66+
f"beta_{self.name}", shape=(k_endog, k_states) if k_endog > 1 else (k_states,)
67+
)
6768
regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states))
6869

69-
self.ssm["initial_state", :] = betas.reshape((1, -1)).squeeze()
70-
T = np.eye(k_states)
71-
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
72-
self.ssm["selection", :, :] = np.eye(self.k_states)
70+
self.ssm["initial_state", :] = betas.ravel()
71+
self.ssm["transition", :, :] = pt.eye(self.k_states)
72+
self.ssm["selection", :, :] = pt.eye(self.k_states)
73+
7374
Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)])
7475
self.ssm["design"] = pt.specify_shape(
7576
Z, (None, k_endog, regression_data.type.shape[1] * k_endog)
7677
)
7778

7879
if self.innovations:
7980
sigma_beta = self.make_and_register_variable(
80-
f"sigma_beta_{self.name}", (self.k_states,)
81+
f"sigma_beta_{self.name}", (k_states,) if k_endog == 1 else (k_endog, k_states)
8182
)
8283
row_idx, col_idx = np.diag_indices(self.k_states)
83-
self.ssm["state_cov", row_idx, col_idx] = sigma_beta**2
84+
self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2
8485

8586
def populate_component_properties(self) -> None:
8687
k_endog = self.k_endog
8788
k_states = self.k_states // k_endog
88-
self.k_posdef // k_endog
8989

9090
self.shock_names = self.state_names
9191

9292
self.param_names = [f"beta_{self.name}"]
9393
self.data_names = [f"data_{self.name}"]
9494
self.param_dims = {
95-
f"beta_{self.name}": ("exog_endog", "exog_state"),
95+
f"beta_{self.name}": (f"{self.name}_endog", f"{self.name}_state"),
9696
}
9797

9898
base_names = self.state_names
@@ -102,9 +102,11 @@ def populate_component_properties(self) -> None:
102102

103103
self.param_info = {
104104
f"beta_{self.name}": {
105-
"shape": (k_endog, k_states),
105+
"shape": (k_endog, k_states) if k_endog > 1 else (k_states,),
106106
"constraints": None,
107-
"dims": ("exog_endog", "exog_state"),
107+
"dims": (f"{self.name}_endog", f"{self.name}_state")
108+
if k_endog > 1
109+
else (f"{self.name}_state",),
108110
},
109111
}
110112

@@ -114,13 +116,18 @@ def populate_component_properties(self) -> None:
114116
"dims": (TIME_DIM, "exog_state"),
115117
},
116118
}
117-
self.coords = {"exog_state": base_names, "exog_endog": self.observed_state_names}
119+
self.coords = {
120+
f"{self.name}_state": self.state_names,
121+
f"{self.name}_endog": self.observed_state_names,
122+
}
118123

119124
if self.innovations:
120125
self.param_names += [f"sigma_beta_{self.name}"]
121-
self.param_dims[f"sigma_beta_{self.name}"] = "exog_state"
126+
self.param_dims[f"sigma_beta_{self.name}"] = f"{self.name}_state"
122127
self.param_info[f"sigma_beta_{self.name}"] = {
123128
"shape": (),
124129
"constraints": "Positive",
125-
"dims": ("exog_state",),
130+
"dims": (f"{self.name}_state",)
131+
if k_endog == 1
132+
else (f"{self.name}_endog", f"{self.name}_state"),
126133
}

tests/statespace/models/structural/components/test_regression.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import assert_allclose
66
from pytensor import config
77
from pytensor import tensor as pt
8+
from pytensor.graph.basic import explicit_graph_inputs
89

910
from pymc_extras.statespace.models import structural as st
1011
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
@@ -25,7 +26,7 @@ def test_exogenous_component(rng):
2526
# Check that the generated data is just a linear regression
2627
assert_allclose(y, data @ params["beta_exog"], atol=ATOL, rtol=RTOL)
2728

28-
mod.build(verbose=False)
29+
mod = mod.build(verbose=False)
2930
_assert_basic_coords_correct(mod)
3031
assert mod.coords["exog_state"] == ["feature_1", "feature_2"]
3132

@@ -42,6 +43,73 @@ def test_adding_exogenous_component(rng):
4243
assert_allclose(mod.ssm["design", 5, 0, :2].eval({"data_exog": data}), data[5])
4344

4445

46+
def test_regression_with_multiple_observed_states(rng):
47+
from scipy.linalg import block_diag
48+
49+
data = rng.normal(size=(100, 2)).astype(config.floatX)
50+
mod = st.RegressionComponent(
51+
state_names=["feature_1", "feature_2"],
52+
name="exog",
53+
observed_state_names=["data_1", "data_2"],
54+
)
55+
56+
params = {"beta_exog": np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)}
57+
exog_data = {"data_exog": data}
58+
x, y = simulate_from_numpy_model(mod, rng, params, exog_data)
59+
60+
assert x.shape == (100, 4) # 2 features, 2 states
61+
assert y.shape == (100, 2)
62+
63+
# Check that the generated data are two independent linear regressions
64+
assert_allclose(y[:, 0], data @ params["beta_exog"][0], atol=ATOL, rtol=RTOL)
65+
assert_allclose(y[:, 1], data @ params["beta_exog"][1], atol=ATOL, rtol=RTOL)
66+
67+
mod = mod.build(verbose=False)
68+
assert mod.coords["exog_state"] == [
69+
"feature_1[data_1]",
70+
"feature_2[data_1]",
71+
"feature_1[data_2]",
72+
"feature_2[data_2]",
73+
]
74+
75+
Z = mod.ssm["design"].eval({"data_exog": data})
76+
vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)")
77+
assert Z.shape == (100, 2, 4)
78+
assert np.allclose(Z, vec_block_diag(data[:, None, :], data[:, None, :]))
79+
80+
81+
def test_add_regression_components_with_multiple_observed_states(rng):
82+
from scipy.linalg import block_diag
83+
84+
data_1 = rng.normal(size=(100, 2)).astype(config.floatX)
85+
data_2 = rng.normal(size=(100, 1)).astype(config.floatX)
86+
87+
reg1 = st.RegressionComponent(
88+
state_names=["a", "b"], name="exog1", observed_state_names=["data_1", "data_2"]
89+
)
90+
reg2 = st.RegressionComponent(state_names=["c"], name="exog2", observed_state_names=["data_3"])
91+
92+
mod = (reg1 + reg2).build(verbose=False)
93+
assert mod.coords["exog1_state"] == ["a[data_1]", "b[data_1]", "a[data_2]", "b[data_2]"]
94+
assert mod.coords["exog2_state"] == ["c[data_3]"]
95+
96+
Z = mod.ssm["design"].eval({"data_exog1": data_1, "data_exog2": data_2})
97+
vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)")
98+
assert Z.shape == (100, 3, 5)
99+
assert np.allclose(
100+
Z,
101+
vec_block_diag(vec_block_diag(data_1[:, None, :], data_1[:, None, :]), data_2[:, None, :]),
102+
)
103+
104+
x0 = mod.ssm["initial_state"].eval(
105+
{
106+
"beta_exog1": np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX),
107+
"beta_exog2": np.array([5.0], dtype=config.floatX),
108+
}
109+
)
110+
np.testing.assert_allclose(x0, np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=config.floatX))
111+
112+
45113
def test_filter_scans_time_varying_design_matrix(rng):
46114
time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100)
47115
data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx)

tests/statespace/models/structural/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_extract_components_from_idata(rng):
106106
filter_prior = mod.sample_conditional_prior(prior)
107107
comp_prior = mod.extract_components_from_idata(filter_prior)
108108
comp_states = comp_prior.filtered_prior.coords["state"].values
109-
expected_states = ["LevelTrend[level]", "LevelTrend[trend]", "seasonal", "exog[a]", "exog[b]"]
109+
expected_states = ["level_trend[level]", "level_trend[trend]", "seasonal", "exog[a]", "exog[b]"]
110110
missing = set(comp_states) - set(expected_states)
111111

112112
assert len(missing) == 0, missing

0 commit comments

Comments
 (0)