11from abc import ABC , abstractmethod
2- from typing import Literal , get_args
32
43import pandas as pd
54import pymc as pm
65
76from model .modular .utilities import (
7+ PRIOR_DEFAULT_KWARGS ,
88 ColumnType ,
9+ PoolingType ,
910 get_X_data ,
10- hierarchical_prior_to_requested_depth ,
11+ make_hierarchical_prior ,
1112 select_data_columns ,
1213)
1314from patsy import dmatrix
1415
15- PoolingType = Literal ["none" , "complete" , "partial" , None ]
16- valid_pooling = get_args (PoolingType )
17-
18-
19- def _validate_pooling_params (pooling_columns : ColumnType , pooling : PoolingType ):
20- """
21- Helper function to validate inputs to a GLM component.
22-
23- Parameters
24- ----------
25- pooling_columns: str or list of str
26- Data columns used to construct a hierarchical prior
27- pooling: str
28- Type of pooling to use in the component
29-
30- Returns
31- -------
32- None
33- """
34- if pooling_columns is not None and pooling == "complete" :
35- raise ValueError ("Index data provided but complete pooling was requested" )
36- if pooling_columns is None and pooling != "complete" :
37- raise ValueError (
38- "Index data must be provided for partial pooling (pooling = 'partial') or no pooling "
39- "(pooling = 'none')"
40- )
41-
4216
4317class GLMModel (ABC ):
4418 """Base class for GLM components. Subclasses should implement the build method to construct the component."""
@@ -91,7 +65,7 @@ def __init__(
9165 self ,
9266 name : str | None = None ,
9367 * ,
94- pooling_cols : ColumnType = None ,
68+ pooling_columns : ColumnType = None ,
9569 pooling : PoolingType = "complete" ,
9670 hierarchical_params : dict | None = None ,
9771 prior : str = "Normal" ,
@@ -108,7 +82,7 @@ def __init__(
10882 ----------
10983 name: str, optional
11084 Name of the intercept term. If None, a default name is generated from the pooling columns.
111- pooling_cols : str or list of str, optional
85+ pooling_columns : str or list of str, optional
11286 Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
11387 If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
11488 to right, with the mean of each subsequent level centered on the mean of the previous level.
@@ -133,37 +107,41 @@ def __init__(
133107 Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
134108
135109 """
136- _validate_pooling_params (pooling_cols , pooling )
137-
138- self .pooling_cols = pooling_cols
139110 self .hierarchical_params = hierarchical_params if hierarchical_params is not None else {}
140- self .pooling = pooling if pooling_cols is not None else "complete"
111+ self .pooling = pooling
141112
142113 self .prior = prior
143114 self .prior_params = prior_params if prior_params is not None else {}
144115
145- if pooling_cols is None :
146- pooling_cols = []
147- elif isinstance (pooling_cols , str ):
148- pooling_cols = [pooling_cols ]
116+ if pooling_columns is None :
117+ pooling_columns = []
118+ elif isinstance (pooling_columns , str ):
119+ pooling_columns = [pooling_columns ]
149120
150- name = name or f"Intercept(pooling_cols={ pooling_cols } )"
121+ self .pooling_columns = pooling_columns
122+ name = name or f"Intercept(pooling_cols={ pooling_columns } )"
151123
152124 super ().__init__ (name = name )
153125
154126 def build (self , model : pm .Model | None = None ):
155127 model = pm .modelcontext (model )
156128 with model :
157129 if self .pooling == "complete" :
158- intercept = getattr (pm , self .prior .title ())(f"{ self .name } " , ** self .prior_params )
130+ prior_params = PRIOR_DEFAULT_KWARGS [self .prior ].copy ()
131+ prior_params .update (self .prior_params )
132+
133+ intercept = getattr (pm , self .prior )(f"{ self .name } " , ** prior_params )
159134 return intercept
160135
161- intercept = hierarchical_prior_to_requested_depth (
136+ intercept = make_hierarchical_prior (
162137 self .name ,
163- df = get_X_data (model )[ self . pooling_cols ] ,
138+ X = get_X_data (model ),
164139 model = model ,
140+ pooling_columns = self .pooling_columns ,
165141 dims = None ,
166- no_pooling = self .pooling == "none" ,
142+ pooling = self .pooling ,
143+ prior = self .prior ,
144+ prior_kwargs = self .prior_params ,
167145 ** self .hierarchical_params ,
168146 )
169147
@@ -219,8 +197,6 @@ def __init__(
219197 prior_params:
220198 Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
221199 """
222- _validate_pooling_params (pooling_columns , pooling )
223-
224200 self .feature_columns = feature_columns
225201 self .pooling = pooling
226202 self .pooling_columns = pooling_columns
@@ -248,7 +224,7 @@ def build(self, model=None):
248224 )
249225 return X @ beta
250226
251- beta = hierarchical_prior_to_requested_depth (
227+ beta = make_hierarchical_prior (
252228 self .name ,
253229 self .index_data ,
254230 model = model ,
@@ -318,7 +294,6 @@ def __init__(
318294 offset_dist: str, one of ["zerosum", "normal", "laplace"]
319295 Name of the distribution to use for the offset distribution. Default is "zerosum"
320296 """
321- _validate_pooling_params (index_data , pooling )
322297 self .name = name if name else f"Spline({ feature_column } )"
323298 self .feature_column = feature_column
324299 self .n_knots = n_knots
@@ -352,7 +327,7 @@ def build(self, model: pm.Model | None = None):
352327
353328 elif self .pooling_columns is not None :
354329 X = select_data_columns (self .pooling_columns , model )
355- beta = hierarchical_prior_to_requested_depth (
330+ beta = make_hierarchical_prior (
356331 name = self .name ,
357332 X = X ,
358333 model = model ,
0 commit comments