Skip to content

Commit 8a630ce

Browse files
Refactor regression component
1 parent 155dcd6 commit 8a630ce

File tree

3 files changed

+80
-28
lines changed

3 files changed

+80
-28
lines changed

pymc_experimental/model/modular/components.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
PRIOR_DEFAULT_KWARGS,
88
ColumnType,
99
PoolingType,
10+
at_least_list,
1011
get_X_data,
1112
make_hierarchical_prior,
1213
select_data_columns,
@@ -112,13 +113,8 @@ def __init__(
112113

113114
self.prior = prior
114115
self.prior_params = prior_params if prior_params is not None else {}
116+
self.pooling_columns = at_least_list(pooling_columns)
115117

116-
if pooling_columns is None:
117-
pooling_columns = []
118-
elif isinstance(pooling_columns, str):
119-
pooling_columns = [pooling_columns]
120-
121-
self.pooling_columns = pooling_columns
122118
name = name or f"Intercept(pooling_cols={pooling_columns})"
123119

124120
super().__init__(name=name)
@@ -158,7 +154,7 @@ def __init__(
158154
pooling: PoolingType = "complete",
159155
pooling_columns: ColumnType | None = None,
160156
hierarchical_params: dict | None = None,
161-
**prior_params,
157+
prior_params: dict | None = None,
162158
):
163159
"""
164160
Class to represent a regression component in a GLM model.
@@ -197,12 +193,13 @@ def __init__(
197193
prior_params:
198194
Dictionary of additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
199195
"""
200-
self.feature_columns = feature_columns
196+
self.feature_columns = at_least_list(feature_columns)
201197
self.pooling = pooling
202-
self.pooling_columns = pooling_columns
198+
self.pooling_columns = at_least_list(pooling_columns)
203199

204200
self.prior = prior
205-
self.prior_params = prior_params
201+
self.prior_params = {} if prior_params is None else prior_params
202+
self.hierarchical_params = {} if hierarchical_params is None else hierarchical_params
206203

207204
name = name if name else f"Regression({feature_columns})"
208205

@@ -213,23 +210,27 @@ def build(self, model=None):
213210
feature_dim = f"{self.name}_features"
214211

215212
if feature_dim not in model.coords:
216-
model.add_coord(feature_dim, self.X.columns)
213+
model.add_coord(feature_dim, self.feature_columns)
217214

218215
with model:
219-
X = select_data_columns(get_X_data(model), self.feature_columns)
216+
full_X = get_X_data(model)
217+
X = select_data_columns(self.feature_columns, model, squeeze=False)
220218

221219
if self.pooling == "complete":
222-
beta = getattr(pm, self.prior)(
223-
f"{self.name}", **self.prior_params, dims=[feature_dim]
224-
)
220+
prior_params = PRIOR_DEFAULT_KWARGS[self.prior].copy()
221+
prior_params.update(self.prior_params)
222+
223+
beta = getattr(pm, self.prior)(f"{self.name}", **prior_params, dims=[feature_dim])
225224
return X @ beta
226225

227226
beta = make_hierarchical_prior(
228-
self.name,
229-
self.index_data,
227+
name=self.name,
228+
X=full_X,
229+
pooling=self.pooling,
230+
pooling_columns=self.pooling_columns,
230231
model=model,
231232
dims=[feature_dim],
232-
no_pooling=self.pooling == "none",
233+
**self.hierarchical_params,
233234
)
234235

235236
regression_effect = (X * beta.T).sum(axis=-1)

pymc_experimental/model/modular/utilities.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def select_data_columns(
4242
cols: str | Sequence[str] | None,
4343
model: pm.Model | None = None,
4444
data_name: str = "X_data",
45+
squeeze=True,
4546
) -> pt.TensorVariable | None:
4647
"""
4748
Create a tensor variable representing a subset of independent data columns.
@@ -72,7 +73,7 @@ def select_data_columns(
7273
cols_idx = [model.coords["feature"].index(col) for col in cols]
7374

7475
# Single columns are returned as 1d arrays unless squeeze is False
75-
if len(cols_idx) == 1:
76+
if len(cols_idx) == 1 and squeeze:
7677
cols_idx = cols_idx[0]
7778

7879
return get_X_data(model, data_name=data_name)[:, cols_idx]
@@ -110,6 +111,14 @@ def encode_categoricals(df, coords):
110111
return df, coords
111112

112113

114+
def at_least_list(columns: "ColumnType | None") -> "list[str]":
    """
    Normalize a column specification so that it is always a list.

    Parameters
    ----------
    columns:
        ``None``, a single column name, or a sequence of column names.

    Returns
    -------
    list
        ``[]`` when ``columns`` is ``None``; ``[columns]`` when it is a single string; the same
        list object (not a copy) when it is already a list; otherwise a new list built from the
        elements of ``columns`` (e.g. for a tuple).
    """
    if columns is None:
        return []

    # A bare string is one column name, not a sequence of single-character names.
    if isinstance(columns, str):
        return [columns]

    # Lists are handed back untouched so existing callers keep the exact object they passed in.
    if isinstance(columns, list):
        return columns

    # Any other sequence (tuple, index-like, ...) is materialized so the result is truly a list.
    return list(columns)
120+
121+
113122
def make_level_maps(X: SharedVariable, coords: dict[str, tuple | None], ordered_levels: list[str]):
114123
r"""
115124
For each row of data, create a mapping between levels of an arbitrary set of levels defined by `ordered_levels`.
@@ -304,7 +313,7 @@ def make_partial_pooled_hierarchy(
304313
prior_params.update(prior_kwargs)
305314

306315
with model:
307-
beta = Prior(f"{name}_effect", **prior_params, dims=dims)
316+
beta = Prior(f"{name}", **prior_params, dims=dims)
308317

309318
for i, (last_level, level) in enumerate(itertools.pairwise([None, *pooling_columns])):
310319
if i == 0:
@@ -359,13 +368,17 @@ def make_unpooled_hierarchy(
359368
beta = Prior(f"{name}_mu", **prior_kwargs, dims=dims)
360369

361370
for i, (last_level, level) in enumerate(itertools.pairwise([None, *levels])):
362-
sigma = make_sigma(f"{name}_{level}_sigma", sigma_dist, sigma_kwargs, dims)
371+
if i == 0:
372+
sigma_dims = dims
373+
else:
374+
sigma_dims = [*dims, last_level] if dims is not None else [last_level]
375+
beta_dims = [*dims, level] if dims is not None else [level]
376+
377+
sigma = make_sigma(f"{name}_{level}_effect", sigma_dist, sigma_kwargs, sigma_dims)
363378

364379
prior_kwargs["mu"] = beta[..., idx_maps[i]]
365380
scale_name = "b" if prior == "Laplace" else "sigma"
366-
prior_kwargs[scale_name] = sigma
367-
368-
beta_dims = [*dims, level] if dims is not None else [level]
381+
prior_kwargs[scale_name] = sigma[..., idx_maps[i]]
369382

370383
beta = Prior(f"{name}_{level}_effect", **prior_kwargs, dims=beta_dims)
371384

tests/model/modular/test_components.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import pymc as pm
44
import pytest
55

6-
from model.modular.utilities import encode_categoricals
6+
from model.modular.utilities import at_least_list, encode_categoricals
77

8-
from pymc_experimental.model.modular.components import Intercept, PoolingType
8+
from pymc_experimental.model.modular.components import Intercept, PoolingType, Regression
99

1010

1111
@pytest.fixture(scope="session")
@@ -51,5 +51,43 @@ def test_intercept(pooling: PoolingType, prior, model):
5151
assert np.unique(x).shape[0] == 1
5252

5353

54-
def test_regression():
55-
pass
54+
@pytest.mark.parametrize("pooling", ["partial", "none", "complete"], ids=str)
@pytest.mark.parametrize("prior", ["Normal", "Laplace", "StudentT"], ids=str)
@pytest.mark.parametrize(
    "feature_columns", ["income", ["age", "income"]], ids=["single", "multiple"]
)
def test_regression(pooling: PoolingType, prior, feature_columns, model):
    """
    Build a ``Regression`` component on a copy of the fixture model and check what it registers.

    The test verifies that the feature coordinate is added, that the variables expected for the
    requested pooling type exist under the expected names, and that the number of distinct
    coefficient values matches the pooling type.
    """
    component = Regression(
        name=None,
        feature_columns=feature_columns,
        prior=prior,
        pooling=pooling,
        pooling_columns="city",
    )

    # Work on a copy so the session-scoped fixture model is not modified by build().
    built_model = model.copy()
    xb = component.build(built_model)

    # With name=None the component falls back to this default name.
    base = f"Regression({feature_columns})"
    assert f"{base}_features" in built_model.coords

    # NOTE(review): branch nesting reconstructed from a flattened diff view -- confirm that the
    # plain-name assertion belongs to the "complete" case only.
    if pooling == "complete":
        assert base in built_model.named_vars
    else:
        assert f"{base}_city_effect" in built_model.named_vars
        assert f"{base}_city_effect_sigma" in built_model.named_vars

        if pooling == "partial":
            assert f"{base}_city_effect_offset" in built_model.named_vars_to_dims

    xb_val = xb.eval()

    # Walk the graph of the returned expression to recover the coefficient tensor.
    _, beta = xb.owner.inputs[0].owner.inputs
    n_distinct = np.unique(beta.eval()).shape[0]
    n_features = len(at_least_list(feature_columns))

    if pooling == "complete":
        assert n_distinct == n_features
    else:
        assert xb_val.shape[0] == len(model.coords["obs_idx"])
        assert n_distinct == len(model.coords["city"]) * n_features

0 commit comments

Comments
 (0)