Skip to content

Commit 155dcd6

Browse files
Refactor intercept class
1 parent b243326 commit 155dcd6

File tree

5 files changed

+252
-130
lines changed

5 files changed

+252
-130
lines changed

pymc_experimental/model/modular/components.py

Lines changed: 24 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -1,44 +1,18 @@
11
from abc import ABC, abstractmethod
2-
from typing import Literal, get_args
32

43
import pandas as pd
54
import pymc as pm
65

76
from model.modular.utilities import (
7+
PRIOR_DEFAULT_KWARGS,
88
ColumnType,
9+
PoolingType,
910
get_X_data,
10-
hierarchical_prior_to_requested_depth,
11+
make_hierarchical_prior,
1112
select_data_columns,
1213
)
1314
from patsy import dmatrix
1415

15-
PoolingType = Literal["none", "complete", "partial", None]
16-
valid_pooling = get_args(PoolingType)
17-
18-
19-
def _validate_pooling_params(pooling_columns: ColumnType, pooling: PoolingType):
20-
"""
21-
Helper function to validate inputs to a GLM component.
22-
23-
Parameters
24-
----------
25-
pooling_columns: str or list of str
26-
Data columns used to construct a hierarchical prior
27-
pooling: str
28-
Type of pooling to use in the component
29-
30-
Returns
31-
-------
32-
None
33-
"""
34-
if pooling_columns is not None and pooling == "complete":
35-
raise ValueError("Index data provided but complete pooling was requested")
36-
if pooling_columns is None and pooling != "complete":
37-
raise ValueError(
38-
"Index data must be provided for partial pooling (pooling = 'partial') or no pooling "
39-
"(pooling = 'none')"
40-
)
41-
4216

4317
class GLMModel(ABC):
4418
"""Base class for GLM components. Subclasses should implement the build method to construct the component."""
@@ -91,7 +65,7 @@ def __init__(
9165
self,
9266
name: str | None = None,
9367
*,
94-
pooling_cols: ColumnType = None,
68+
pooling_columns: ColumnType = None,
9569
pooling: PoolingType = "complete",
9670
hierarchical_params: dict | None = None,
9771
prior: str = "Normal",
@@ -108,7 +82,7 @@ def __init__(
10882
----------
10983
name: str, optional
11084
Name of the intercept term. If None, a default name is generated based on the index_data.
111-
pooling_cols: str or list of str, optional
85+
pooling_columns: str or list of str, optional
11286
Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
11387
If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
11488
to right, with the mean of each subsequent level centered on the mean of the previous level.
@@ -133,37 +107,41 @@ def __init__(
133107
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
134108
135109
"""
136-
_validate_pooling_params(pooling_cols, pooling)
137-
138-
self.pooling_cols = pooling_cols
139110
self.hierarchical_params = hierarchical_params if hierarchical_params is not None else {}
140-
self.pooling = pooling if pooling_cols is not None else "complete"
111+
self.pooling = pooling
141112

142113
self.prior = prior
143114
self.prior_params = prior_params if prior_params is not None else {}
144115

145-
if pooling_cols is None:
146-
pooling_cols = []
147-
elif isinstance(pooling_cols, str):
148-
pooling_cols = [pooling_cols]
116+
if pooling_columns is None:
117+
pooling_columns = []
118+
elif isinstance(pooling_columns, str):
119+
pooling_columns = [pooling_columns]
149120

150-
name = name or f"Intercept(pooling_cols={pooling_cols})"
121+
self.pooling_columns = pooling_columns
122+
name = name or f"Intercept(pooling_cols={pooling_columns})"
151123

152124
super().__init__(name=name)
153125

154126
def build(self, model: pm.Model | None = None):
155127
model = pm.modelcontext(model)
156128
with model:
157129
if self.pooling == "complete":
158-
intercept = getattr(pm, self.prior.title())(f"{self.name}", **self.prior_params)
130+
prior_params = PRIOR_DEFAULT_KWARGS[self.prior].copy()
131+
prior_params.update(self.prior_params)
132+
133+
intercept = getattr(pm, self.prior)(f"{self.name}", **prior_params)
159134
return intercept
160135

161-
intercept = hierarchical_prior_to_requested_depth(
136+
intercept = make_hierarchical_prior(
162137
self.name,
163-
df=get_X_data(model)[self.pooling_cols],
138+
X=get_X_data(model),
164139
model=model,
140+
pooling_columns=self.pooling_columns,
165141
dims=None,
166-
no_pooling=self.pooling == "none",
142+
pooling=self.pooling,
143+
prior=self.prior,
144+
prior_kwargs=self.prior_params,
167145
**self.hierarchical_params,
168146
)
169147

@@ -219,8 +197,6 @@ def __init__(
219197
prior_params:
220198
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
221199
"""
222-
_validate_pooling_params(pooling_columns, pooling)
223-
224200
self.feature_columns = feature_columns
225201
self.pooling = pooling
226202
self.pooling_columns = pooling_columns
@@ -248,7 +224,7 @@ def build(self, model=None):
248224
)
249225
return X @ beta
250226

251-
beta = hierarchical_prior_to_requested_depth(
227+
beta = make_hierarchical_prior(
252228
self.name,
253229
self.index_data,
254230
model=model,
@@ -318,7 +294,6 @@ def __init__(
318294
offset_dist: str, one of ["zerosum", "normal", "laplace"]
319295
Name of the distribution to use for the offset distribution. Default is "zerosum"
320296
"""
321-
_validate_pooling_params(index_data, pooling)
322297
self.name = name if name else f"Spline({feature_column})"
323298
self.feature_column = feature_column
324299
self.n_knots = n_knots
@@ -352,7 +327,7 @@ def build(self, model: pm.Model | None = None):
352327

353328
elif self.pooling_columns is not None:
354329
X = select_data_columns(self.pooling_columns, model)
355-
beta = hierarchical_prior_to_requested_depth(
330+
beta = make_hierarchical_prior(
356331
name=self.name,
357332
X=X,
358333
model=model,

pymc_experimental/model/modular/likelihood.py

Lines changed: 2 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
from typing import Literal, get_args
44

55
import arviz as az
6-
import numpy as np
76
import pandas as pd
87
import pymc as pm
98
import pytensor.tensor as pt
@@ -14,7 +13,7 @@
1413
from pytensor.tensor.random.type import RandomType
1514

1615
from pymc_experimental.model.marginal.marginal_model import MarginalModel
17-
from pymc_experimental.model.modular.utilities import ColumnType
16+
from pymc_experimental.model.modular.utilities import ColumnType, encode_categoricals
1817

1918
LIKELIHOOD_TYPES = Literal["lognormal", "logt", "mixture", "unmarginalized-mixture"]
2019
valid_likelihoods = get_args(LIKELIHOOD_TYPES)
@@ -42,33 +41,14 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):
4241
[target_col] = target_col
4342
self.target_col = target_col
4443

45-
# TODO: Reconsider this (two sources of nearly the same info not good)
4644
X_df = data.drop(columns=[target_col])
4745

4846
self.obs_dim = data.index.name
4947
self.coords = {
5048
self.obs_dim: data.index.values,
5149
}
5250

53-
for col, dtype in X_df.dtypes.to_dict().items():
54-
if dtype.name.startswith("float"):
55-
pass
56-
elif dtype.name == "object":
57-
# TODO: We definitely need to save these if we want to factorize predict data
58-
col_array, labels = pd.factorize(X_df[col], sort=True)
59-
X_df[col] = col_array.astype("float64")
60-
self.coords[col] = labels
61-
elif dtype.name.startswith("int"):
62-
_data = X_df[col].copy()
63-
X_df[col] = X_df[col].astype("float64")
64-
assert np.all(
65-
_data == X_df[col].astype("int")
66-
), "Information was lost in conversion to float"
67-
68-
else:
69-
raise NotImplementedError(
70-
f"Haven't decided how to handle the following type: {dtype.name}"
71-
)
51+
X_df, self.coords = encode_categoricals(X_df, self.coords)
7252

7353
numeric_cols = [
7454
col for col, dtype in X_df.dtypes.to_dict().items() if dtype.name.startswith("float")

0 commit comments

Comments
 (0)