66import pymc as pm
77import pytensor .tensor as pt
88
9- from model .modular .utilities import ColumnType , hierarchical_prior_to_requested_depth
9+ from model .modular .utilities import ColumnType , get_X_data , hierarchical_prior_to_requested_depth
1010from patsy import dmatrix
1111
1212POOLING_TYPES = Literal ["none" , "complete" , "partial" ]
@@ -105,7 +105,6 @@ def __init__(
105105 prior_params : dict | None = None ,
106106 ):
107107 """
108- TODO: Update signature docs
109108 Class to represent an intercept term in a GLM model.
110109
111110 By intercept, it is meant any constant term in the model that is not a function of any input data. This can be
@@ -116,21 +115,15 @@ def __init__(
116115 ----------
117116 name: str, optional
118117 Name of the intercept term. If None, a default name is generated based on the index_data.
119- index_data: Series or DataFrame, optional
120- Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
121- levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
122- and depth increasing to the right.
123-
124- The index of the index_data must match the index of the observed data.
125- prior: str, optional
126- Name of the PyMC distribution to use for the intercept term. Default is "Normal".
118+ pooling_cols: str or list of str, optional
119+ Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
120+ If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
121+ to right, with the mean of each subsequent level centered on the mean of the previous level.
127122 pooling: str, one of ["none", "complete", "partial"], default "complete"
128123 Type of pooling to use for the intercept term. If "none", no pooling is applied, and each group in the
129124 index_data is treated as independent. If "complete", complete pooling is applied, and all data are treated
130125 as coming from the same group. If "partial", a hierarchical prior is constructed that shares information
131126 across groups in the index_data.
132- prior_params: dict, optional
133- Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
134127 hierarchical_params: dict, optional
135128 Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth function.
136129 Options include:
@@ -141,6 +134,11 @@ def __init__(
141134 Default is {"alpha": 2, "beta": 1}
142135 offset_dist: str, one of ["zerosum", "normal", "laplace"]
143136 Name of the distribution to use for the offset distribution. Default is "zerosum"
137+ prior: str, optional
138+ Name of the PyMC distribution to use for the intercept term. Default is "Normal".
139+ prior_params: dict, optional
140+ Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
141+
144142 """
145143 _validate_pooling_params (pooling_cols , pooling )
146144
@@ -158,25 +156,25 @@ def __init__(
158156
159157 data_name = ", " .join (pooling_cols )
160158 self .name = name or f"Constant(pooling_cols={ data_name } )"
159+
161160 super ().__init__ ()
162161
163- def build (self , model = None ):
162+ def build (self , model : pm . Model | None = None ):
164163 model = pm .modelcontext (model )
165164 with model :
166165 if self .pooling == "complete" :
167166 intercept = getattr (pm , self .prior )(f"{ self .name } " , ** self .prior_params )
168167 return intercept
169168
170- [i for i , col in enumerate (model .coords ["feature" ]) if col in self .pooling_cols ]
171-
172169 intercept = hierarchical_prior_to_requested_depth (
173170 self .name ,
174- model . X_df [self .pooling_cols ], # TODO: Reconsider this
171+ df = get_X_data ( model ) [self .pooling_cols ],
175172 model = model ,
176173 dims = None ,
177174 no_pooling = self .pooling == "none" ,
178175 ** self .hierarchical_params ,
179176 )
177+
180178 return intercept
181179
182180
0 commit comments