6
6
import pymc as pm
7
7
import pytensor .tensor as pt
8
8
9
- from model .modular .utilities import ColumnType , hierarchical_prior_to_requested_depth
9
+ from model .modular .utilities import ColumnType , get_X_data , hierarchical_prior_to_requested_depth
10
10
from patsy import dmatrix
11
11
12
12
POOLING_TYPES = Literal ["none" , "complete" , "partial" ]
@@ -105,7 +105,6 @@ def __init__(
105
105
prior_params : dict | None = None ,
106
106
):
107
107
"""
108
- TODO: Update signature docs
109
108
Class to represent an intercept term in a GLM model.
110
109
111
110
By intercept, it is meant any constant term in the model that is not a function of any input data. This can be
@@ -116,21 +115,15 @@ def __init__(
116
115
----------
117
116
name: str, optional
118
117
Name of the intercept term. If None, a default name is generated based on the index_data.
119
- index_data: Series or DataFrame, optional
120
- Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
121
- levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
122
- and depth increasing to the right.
123
-
124
- The index of the index_data must match the index of the observed data.
125
- prior: str, optional
126
- Name of the PyMC distribution to use for the intercept term. Default is "Normal".
118
+ pooling_cols: str or list of str, optional
119
+ Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
120
+ If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
121
+ to right, with the mean of each subsequent level centered on the mean of the previous level.
127
122
pooling: str, one of ["none", "complete", "partial"], default "complete"
128
123
Type of pooling to use for the intercept term. If "none", no pooling is applied, and each group in the
129
124
index_data is treated as independent. If "complete", complete pooling is applied, and all data are treated
130
125
as coming from the same group. If "partial", a hierarchical prior is constructed that shares information
131
126
across groups in the index_data.
132
- prior_params: dict, optional
133
- Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
134
127
hierarchical_params: dict, optional
135
128
Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth function.
136
129
Options include:
@@ -141,6 +134,11 @@ def __init__(
141
134
Default is {"alpha": 2, "beta": 1}
142
135
offset_dist: str, one of ["zerosum", "normal", "laplace"]
143
136
Name of the distribution to use for the offset distribution. Default is "zerosum"
137
+ prior: str, optional
138
+ Name of the PyMC distribution to use for the intercept term. Default is "Normal".
139
+ prior_params: dict, optional
140
+ Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
141
+
144
142
"""
145
143
_validate_pooling_params (pooling_cols , pooling )
146
144
@@ -158,25 +156,25 @@ def __init__(
158
156
159
157
data_name = ", " .join (pooling_cols )
160
158
self .name = name or f"Constant(pooling_cols={ data_name } )"
159
+
161
160
super ().__init__ ()
162
161
163
- def build (self , model = None ):
162
+ def build (self , model : pm . Model | None = None ):
164
163
model = pm .modelcontext (model )
165
164
with model :
166
165
if self .pooling == "complete" :
167
166
intercept = getattr (pm , self .prior )(f"{ self .name } " , ** self .prior_params )
168
167
return intercept
169
168
170
- [i for i , col in enumerate (model .coords ["feature" ]) if col in self .pooling_cols ]
171
-
172
169
intercept = hierarchical_prior_to_requested_depth (
173
170
self .name ,
174
- model . X_df [self .pooling_cols ], # TODO: Reconsider this
171
+ df = get_X_data ( model ) [self .pooling_cols ],
175
172
model = model ,
176
173
dims = None ,
177
174
no_pooling = self .pooling == "none" ,
178
175
** self .hierarchical_params ,
179
176
)
177
+
180
178
return intercept
181
179
182
180
0 commit comments