Skip to content

Commit 7b3cdcf

Browse files
Update intercept docstring
1 parent 404550d commit 7b3cdcf

File tree

3 files changed

+25
-24
lines changed

3 files changed

+25
-24
lines changed

pymc_experimental/model/modular/components.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pymc as pm
77
import pytensor.tensor as pt
88

9-
from model.modular.utilities import ColumnType, hierarchical_prior_to_requested_depth
9+
from model.modular.utilities import ColumnType, get_X_data, hierarchical_prior_to_requested_depth
1010
from patsy import dmatrix
1111

1212
POOLING_TYPES = Literal["none", "complete", "partial"]
@@ -105,7 +105,6 @@ def __init__(
105105
prior_params: dict | None = None,
106106
):
107107
"""
108-
TODO: Update signature docs
109108
Class to represent an intercept term in a GLM model.
110109
111110
By intercept, we mean any constant term in the model that is not a function of any input data. This can be
@@ -116,21 +115,15 @@ def __init__(
116115
----------
117116
name: str, optional
118117
Name of the intercept term. If None, a default name is generated based on the pooling_cols.
119-
index_data: Series or DataFrame, optional
120-
Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
121-
levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
122-
and depth increasing to the right.
123-
124-
The index of the index_data must match the index of the observed data.
125-
prior: str, optional
126-
Name of the PyMC distribution to use for the intercept term. Default is "Normal".
118+
pooling_cols: str or list of str, optional
119+
Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
120+
If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
121+
to right, with the mean of each subsequent level centered on the mean of the previous level.
127122
pooling: str, one of ["none", "complete", "partial"], default "complete"
128123
Type of pooling to use for the intercept term. If "none", no pooling is applied, and each group in the
129124
pooling_cols is treated as independent. If "complete", complete pooling is applied, and all data are treated
130125
as coming from the same group. If "partial", a hierarchical prior is constructed that shares information
131126
across groups in the pooling_cols.
132-
prior_params: dict, optional
133-
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
134127
hierarchical_params: dict, optional
135128
Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth function.
136129
Options include:
@@ -141,6 +134,11 @@ def __init__(
141134
Default is {"alpha": 2, "beta": 1}
142135
offset_dist: str, one of ["zerosum", "normal", "laplace"]
143136
Name of the distribution to use for the offset distribution. Default is "zerosum"
137+
prior: str, optional
138+
Name of the PyMC distribution to use for the intercept term. Default is "Normal".
139+
prior_params: dict, optional
140+
Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
141+
144142
"""
145143
_validate_pooling_params(pooling_cols, pooling)
146144

@@ -158,25 +156,25 @@ def __init__(
158156

159157
data_name = ", ".join(pooling_cols)
160158
self.name = name or f"Constant(pooling_cols={data_name})"
159+
161160
super().__init__()
162161

163-
def build(self, model=None):
162+
def build(self, model: pm.Model | None = None):
164163
model = pm.modelcontext(model)
165164
with model:
166165
if self.pooling == "complete":
167166
intercept = getattr(pm, self.prior)(f"{self.name}", **self.prior_params)
168167
return intercept
169168

170-
[i for i, col in enumerate(model.coords["feature"]) if col in self.pooling_cols]
171-
172169
intercept = hierarchical_prior_to_requested_depth(
173170
self.name,
174-
model.X_df[self.pooling_cols], # TODO: Reconsider this
171+
df=get_X_data(model)[self.pooling_cols],
175172
model=model,
176173
dims=None,
177174
no_pooling=self.pooling == "none",
178175
**self.hierarchical_params,
179176
)
177+
180178
return intercept
181179

182180

pymc_experimental/model/modular/likelihood.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):
4343

4444
# TODO: Reconsider this (having two sources of nearly the same info is not good)
4545
X_df = data.drop(columns=[target_col])
46-
X_data = X_df.copy()
46+
4747
self.column_labels = {}
48-
for col, dtype in X_data.dtypes.to_dict().items():
48+
for col, dtype in X_df.dtypes.to_dict().items():
4949
if dtype.name.startswith("float"):
5050
pass
5151
elif dtype.name == "object":
5252
# TODO: We definitely need to save these if we want to factorize predict data
53-
col_array, labels = pd.factorize(X_data[col], sort=True)
54-
X_data[col] = col_array.astype("float64")
53+
col_array, labels = pd.factorize(X_df[col], sort=True)
54+
X_df[col] = col_array.astype("float64")
5555
self.column_labels[col] = {label: i for i, label in enumerate(labels.values)}
5656
elif dtype.name.startswith("int"):
57-
X_data[col] = X_data[col].astype("float64")
57+
X_df[col] = X_df[col].astype("float64")
5858
else:
5959
raise NotImplementedError(
6060
f"Haven't decided how to handle the following type: {dtype.name}"
@@ -63,14 +63,13 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):
6363
self.obs_dim = data.index.name
6464
coords = {
6565
self.obs_dim: data.index.values,
66-
"feature": list(X_data.columns),
66+
"feature": list(X_df.columns),
6767
}
6868
with self._get_model_class(coords) as self.model:
69-
self.model.X_df = X_df # FIXME: Definitely not a solution
7069
pm.Data(f"{target_col}_observed", data[target_col], dims=self.obs_dim)
7170
pm.Data(
7271
"X_data",
73-
X_data,
72+
X_df,
7473
dims=(self.obs_dim, "feature"),
7574
shape=(None, len(coords["feature"])),
7675
)

pymc_experimental/model/modular/utilities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def _get_x_cols(
3737
return model["X_data"][:, cols_idx]
3838

3939

40+
def get_X_data(model, data_name="X_data"):
41+
return model[data_name]
42+
43+
4044
def make_level_maps(df: pd.DataFrame, ordered_levels: list[str]):
4145
"""
4246
For each row of data, create a mapping between levels of an arbitrary set of levels defined by `ordered_levels`.

0 commit comments

Comments
 (0)