Skip to content

Commit ab3b4ed

Browse files
Refactor spline component
1 parent 8a630ce commit ab3b4ed

File tree

3 files changed

+99
-44
lines changed

3 files changed

+99
-44
lines changed

pymc_experimental/model/modular/components.py

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC, abstractmethod
22

3-
import pandas as pd
3+
import numpy as np
44
import pymc as pm
5+
import pytensor.tensor as pt
56

67
from model.modular.utilities import (
78
PRIOR_DEFAULT_KWARGS,
@@ -13,6 +14,7 @@
1314
select_data_columns,
1415
)
1516
from patsy import dmatrix
17+
from pytensor.graph import Apply, Op
1618

1719

1820
class GLMModel(ABC):
@@ -113,7 +115,7 @@ def __init__(
113115

114116
self.prior = prior
115117
self.prior_params = prior_params if prior_params is not None else {}
116-
self.pooling_columns = at_least_list(pooling_columns)
118+
self.pooling_columns = pooling_columns
117119

118120
name = name or f"Intercept(pooling_cols={pooling_columns})"
119121

@@ -193,9 +195,9 @@ def __init__(
193195
prior_params:
194196
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
195197
"""
196-
self.feature_columns = at_least_list(feature_columns)
198+
self.feature_columns = feature_columns
197199
self.pooling = pooling
198-
self.pooling_columns = at_least_list(pooling_columns)
200+
self.pooling_columns = pooling_columns
199201

200202
self.prior = prior
201203
self.prior_params = {} if prior_params is None else prior_params
@@ -210,7 +212,7 @@ def build(self, model=None):
210212
feature_dim = f"{self.name}_features"
211213

212214
if feature_dim not in model.coords:
213-
model.add_coord(feature_dim, self.feature_columns)
215+
model.add_coord(feature_dim, at_least_list(self.feature_columns))
214216

215217
with model:
216218
full_X = get_X_data(model)
@@ -237,17 +239,55 @@ def build(self, model=None):
237239
return regression_effect
238240

239241

240-
class Spline(Regression):
242+
class SplineTensor(Op):
    """
    PyTensor Op that evaluates a patsy B-spline basis for a symbolic input.

    The basis itself is computed numerically in ``perform`` by calling patsy's ``dmatrix`` on the
    concrete value of the input; ``make_node`` only declares the symbolic input and output.
    """

    def __init__(self, name, df=10, degree=3):
        """
        Thin wrapper around patsy dmatrix, allowing for the creation of spline basis functions given a symbolic input.

        Parameters
        ----------
        name: str, optional
            Label for the spline basis. It is used to name the output variable
            (``"{name}_spline_basis"``). If None or empty, an empty string is used.
        df: int
            Number of basis functions to generate
        degree: int
            Degree of the spline basis
        """
        self.name = name if name else ""
        self.df = df
        self.degree = degree

    def make_node(self, x):
        """
        Declare the graph node: one tensor input and one float64 matrix output.

        Parameters
        ----------
        x: tensor-like
            Values at which the spline basis is evaluated. Converted with ``pt.as_tensor``.

        Returns
        -------
        Apply
            Node connecting ``x`` to a new ``pt.dmatrix`` output named ``"{name}_spline_basis"``.
        """
        inputs = [pt.as_tensor(x)]
        outputs = [pt.dmatrix(f"{self.name}_spline_basis")]

        return Apply(self, inputs, outputs)

    def perform(self, node: Apply, inputs: list[np.ndarray], outputs: list[list[None]]) -> None:
        """
        Compute the spline basis for the concrete input and store it in the output storage cell.

        Parameters
        ----------
        node: Apply
            The node being evaluated (unused).
        inputs: list of np.ndarray
            Single-element list holding the numeric value of ``x``.
        outputs: list of list
            Output storage; the basis matrix is written to ``outputs[0][0]``.
        """
        [x] = inputs

        # The formula refers to the data through a fixed placeholder rather than through
        # self.name. self.name is a free-form label: it may be empty (the default when no
        # name is given) or not a valid Python identifier, and either case would make the
        # patsy formula unparsable. The variable name only affects patsy's column labels,
        # which np.asarray discards, so the numeric result is unchanged.
        outputs[0][0] = np.asarray(
            dmatrix(f"bs(x, df={self.df}, degree={self.degree}) - 1", data={"x": x})
        )
272+
273+
274+
def pt_spline(x, name=None, df=10, degree=3) -> pt.Variable:
    """
    Apply a ``SplineTensor`` Op to ``x`` and return its symbolic output.

    Parameters
    ----------
    x: tensor-like
        Input passed to the Op.
    name: str, optional
        Forwarded to ``SplineTensor`` as ``name``.
    df: int, default 10
        Forwarded to ``SplineTensor`` as ``df``.
    degree: int, default 3
        Forwarded to ``SplineTensor`` as ``degree``.

    Returns
    -------
    pt.Variable
        Result of calling the constructed Op on ``x``.
    """
    spline_op = SplineTensor(name=name, df=df, degree=degree)
    return spline_op(x)
276+
277+
278+
class Spline(GLMModel):
241279
def __init__(
242280
self,
243281
name: str,
244282
*,
245283
feature_column: str | None = None,
246284
n_knots: int = 10,
247-
prior: str = "Normal",
248-
index_data: pd.Series | None = None,
285+
spline_degree: int = 3,
249286
pooling: PoolingType = "complete",
250-
**prior_params,
287+
pooling_columns: ColumnType | None = None,
288+
prior: str = "Normal",
289+
prior_params: dict | None = None,
290+
hierarchical_params: dict | None = None,
251291
):
252292
"""
253293
Class to represent a spline component in a GLM model.
@@ -263,25 +303,23 @@ def __init__(
263303
----------
264304
name: str, optional
265305
Name of the spline term. If None, a default name is generated from the feature column, number of knots, and spline degree.
266-
n_knots: int, default 10
267-
Number of knots to use in the spline basis.
268306
feature_column: str
269307
Column of the independent data to use in the spline.
270-
index_data: Series or DataFrame, optional
271-
Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
272-
levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
273-
and depth increasing to the right.
274-
275-
The index of the index_data must match the index of the observed data.
276-
prior: str, optional
277-
Name of the PyMC distribution to use for the intercept term. Default is "Normal".
308+
n_knots: int, default 10
309+
Number of knots to use in the spline basis.
310+
spline_degree: int, default 3
311+
Degree of the spline basis.
278312
pooling: str, one of ["none", "complete", "partial"], default "complete"
279313
Type of pooling to use for the spline coefficients. If "none", no pooling is applied, and each group in the
280314
pooling_columns is treated as independent. If "complete", complete pooling is applied, and all data are treated
281315
as coming from the same group. If "partial", a hierarchical prior is constructed that shares information
282316
across groups in the pooling_columns.
283-
curve_type: str, one of ["log", "abc", "ns", "nss", "box-cox"]
284-
Type of curve to build. For details, see the build_curve function.
317+
pooling_columns: str or list of str, optional
318+
Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
319+
If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
320+
to right, with the mean of each subsequent level centered on the mean of the previous level.
321+
prior: str, optional
322+
Name of the PyMC distribution to use as the prior on the spline coefficients. Default is "Normal".
285323
prior_params: dict, optional
286324
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
287325
hierarchical_params: dict, optional
@@ -295,45 +333,49 @@ def __init__(
295333
offset_dist: str, one of ["zerosum", "normal", "laplace"]
296334
Name of the distribution to use for the offset distribution. Default is "zerosum"
297335
"""
298-
self.name = name if name else f"Spline({feature_column})"
299336
self.feature_column = feature_column
300337
self.n_knots = n_knots
338+
self.spline_degree = spline_degree
339+
301340
self.prior = prior
302-
self.prior_params = prior_params
341+
self.prior_params = {} if prior_params is None else prior_params
342+
self.hierarchical_params = {} if hierarchical_params is None else hierarchical_params
303343

344+
self.pooling = pooling
345+
self.pooling_columns = pooling_columns
346+
347+
name = name if name else f"Spline({feature_column}, df={n_knots}, degree={spline_degree})"
304348
super().__init__(name=name)
305349

306350
def build(self, model: pm.Model | None = None):
307351
model = pm.modelcontext(model)
308-
model.add_coord(f"{self.name}_spline", range(self.n_knots))
352+
spline_dim = f"{self.name}_knots"
353+
model.add_coord(spline_dim, range(self.n_knots))
309354

310355
with model:
311-
spline_data = {
312-
self.feature_column: select_data_columns(
313-
get_X_data(model).get_value(), self.feature_column
314-
)
315-
}
316-
317-
X_spline = dmatrix(
318-
f"bs({self.feature_column}, df={self.n_knots}, degree=3) - 1",
319-
data=spline_data,
320-
return_type="dataframe",
356+
X_spline = pt_spline(
357+
select_data_columns(self.feature_column, model),
358+
name=self.feature_column,
359+
df=self.n_knots,
360+
degree=self.spline_degree,
321361
)
322362

323363
if self.pooling == "complete":
324-
beta = getattr(pm, self.prior)(
325-
f"{self.name}", **self.prior_params, dims=f"{self.feature_column}_spline"
326-
)
364+
prior_params = PRIOR_DEFAULT_KWARGS[self.prior].copy()
365+
prior_params.update(self.prior_params)
366+
367+
beta = getattr(pm, self.prior)(f"{self.name}", **prior_params, dims=[spline_dim])
327368
return X_spline @ beta
328369

329370
elif self.pooling_columns is not None:
330-
X = select_data_columns(self.pooling_columns, model)
331371
beta = make_hierarchical_prior(
332372
name=self.name,
333-
X=X,
373+
X=get_X_data(model),
374+
pooling=self.pooling,
375+
pooling_columns=self.pooling_columns,
334376
model=model,
335-
dims=[f"{self.feature_column}_spline"],
336-
no_pooling=self.pooling == "none",
377+
dims=[spline_dim],
378+
**self.hierarchical_params,
337379
)
338380

339381
spline_effect = (X_spline * beta.T).sum(axis=-1)

pymc_experimental/model/modular/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def select_data_columns(
6363
if cols is None:
6464
return
6565

66-
if isinstance(cols, str):
67-
cols = [cols]
66+
cols = at_least_list(cols)
6867

6968
missing_cols = [col for col in cols if col not in model.coords["feature"]]
7069
if missing_cols:
@@ -407,6 +406,7 @@ def make_hierarchical_prior(
407406
**hierarchy_kwargs,
408407
):
409408
model = pm.modelcontext(model)
409+
pooling_columns = at_least_list(pooling_columns)
410410

411411
if pooling == "none":
412412
return make_unpooled_hierarchy(

tests/model/modular/test_components.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from model.modular.utilities import at_least_list, encode_categoricals
77

8-
from pymc_experimental.model.modular.components import Intercept, PoolingType, Regression
8+
from pymc_experimental.model.modular.components import Intercept, PoolingType, Regression, Spline
99

1010

1111
@pytest.fixture(scope="session")
@@ -91,3 +91,16 @@ def test_regression(pooling: PoolingType, prior, feature_columns, model):
9191
assert np.unique(beta_val).shape[0] == len(model.coords["city"]) * n_features
9292
else:
9393
assert np.unique(beta_val).shape[0] == n_features
94+
95+
96+
@pytest.mark.parametrize("pooling", ["partial", "none", "complete"], ids=str)
@pytest.mark.parametrize("prior", ["Normal", "Laplace", "StudentT"], ids=str)
def test_spline(pooling: PoolingType, prior, model):
    # Smoke test: building a Spline component for every pooling/prior combination
    # must register the knot coordinate under the auto-generated default name.
    component = Spline(
        name=None,
        feature_column="income",
        prior=prior,
        pooling=pooling,
        pooling_columns="city",
    )

    # Build into a copy -- presumably so the shared `model` fixture is not mutated
    # between parametrized runs.
    model_copy = model.copy()
    component.build(model_copy)

    expected_dim = "Spline(income, df=10, degree=3)_knots"
    assert expected_dim in model_copy.coords

0 commit comments

Comments
 (0)