11from abc import ABC , abstractmethod
2- from collections .abc import Sequence
32from typing import Literal , get_args
43
54import pandas as pd
65import pymc as pm
7- import pytensor .tensor as pt
86
9- from model .modular .utilities import ColumnType , get_X_data , hierarchical_prior_to_requested_depth
7+ from model .modular .utilities import (
8+ ColumnType ,
9+ get_X_data ,
10+ hierarchical_prior_to_requested_depth ,
11+ select_data_columns ,
12+ )
1013from patsy import dmatrix
1114
12- POOLING_TYPES = Literal ["none" , "complete" , "partial" ]
13- valid_pooling = get_args (POOLING_TYPES )
15+ PoolingType = Literal ["none" , "complete" , "partial" , None ]
16+ valid_pooling = get_args (PoolingType )
1417
1518
16- def _validate_pooling_params (pooling_columns : ColumnType , pooling : POOLING_TYPES ):
19+ def _validate_pooling_params (pooling_columns : ColumnType , pooling : PoolingType ):
1720 """
1821 Helper function to validate inputs to a GLM component.
1922
2023 Parameters
2124 ----------
22- index_data: Series or DataFrame
23- Index data used to build hierarchical priors
24-
25+ pooling_columns: str or list of str
26+ Data columns used to construct a hierarchical prior
2527 pooling: str
2628 Type of pooling to use in the component
2729
@@ -38,25 +40,13 @@ def _validate_pooling_params(pooling_columns: ColumnType, pooling: POOLING_TYPES
3840 )
3941
4042
41- def _get_x_cols (
42- cols : str | Sequence [str ],
43- model : pm .Model | None = None ,
44- ) -> pt .TensorVariable :
45- model = pm .modelcontext (model )
46- # Don't upcast a single column to a colum matrix
47- if isinstance (cols , str ):
48- [cols_idx ] = [i for i , col in enumerate (model .coords ["feature" ]) if col == cols ]
49- else :
50- cols_idx = [i for i , col in enumerate (model .coords ["feature" ]) if col is cols ]
51- return model ["X_data" ][:, cols_idx ]
52-
53-
5443class GLMModel (ABC ):
5544 """Base class for GLM components. Subclasses should implement the build method to construct the component."""
5645
57- def __init__ (self ):
46+ def __init__ (self , name ):
5847 self .model = None
5948 self .compiled = False
49+ self .name = name
6050
6151 @abstractmethod
6252 def build (self , model = None ):
@@ -68,6 +58,9 @@ def __add__(self, other):
6858 def __mul__ (self , other ):
6959 return MultiplicativeGLMComponent (self , other )
7060
61+ def __str__ (self ):
62+ return self .name
63+
7164
7265class AdditiveGLMComponent (GLMModel ):
7366 """Class to represent an additive combination of GLM components"""
@@ -99,7 +92,7 @@ def __init__(
9992 name : str | None = None ,
10093 * ,
10194 pooling_cols : ColumnType = None ,
102- pooling : POOLING_TYPES = "complete" ,
95+ pooling : PoolingType = "complete" ,
10396 hierarchical_params : dict | None = None ,
10497 prior : str = "Normal" ,
10598 prior_params : dict | None = None ,
@@ -154,16 +147,15 @@ def __init__(
154147 elif isinstance (pooling_cols , str ):
155148 pooling_cols = [pooling_cols ]
156149
157- data_name = ", " .join (pooling_cols )
158- self .name = name or f"Constant(pooling_cols={ data_name } )"
150+ name = name or f"Intercept(pooling_cols={ pooling_cols } )"
159151
160- super ().__init__ ()
152+ super ().__init__ (name = name )
161153
162154 def build (self , model : pm .Model | None = None ):
163155 model = pm .modelcontext (model )
164156 with model :
165157 if self .pooling == "complete" :
166- intercept = getattr (pm , self .prior )(f"{ self .name } " , ** self .prior_params )
158+ intercept = getattr (pm , self .prior . title () )(f"{ self .name } " , ** self .prior_params )
167159 return intercept
168160
169161 intercept = hierarchical_prior_to_requested_depth (
@@ -181,11 +173,13 @@ def build(self, model: pm.Model | None = None):
181173class Regression (GLMModel ):
182174 def __init__ (
183175 self ,
184- name : str ,
185- X : pd .DataFrame ,
176+ name : str | None = None ,
177+ * ,
178+ feature_columns : ColumnType | None = None ,
186179 prior : str = "Normal" ,
187- index_data : pd .Series = None ,
188- pooling : POOLING_TYPES = "complete" ,
180+ pooling : PoolingType = "complete" ,
181+ pooling_columns : ColumnType | None = None ,
182+ hierarchical_params : dict | None = None ,
189183 ** prior_params ,
190184 ):
191185 """
@@ -199,26 +193,19 @@ def __init__(
199193 ----------
200194 name: str, optional
201195 Name of the regression term. If None, a default name is generated based on the feature_columns.
202- X: DataFrame
203- Exogenous data used to build the regression component. Each column of the DataFrame represents a feature
204- used in the regression. Index of the DataFrame should match the index of the observed data.
205- index_data: Series or DataFrame, optional
206- Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
207- levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
208- and depth increasing to the right.
209-
210- The index of the index_data must match the index of the observed data.
196+ feature_columns: str or list of str
197+ Columns of the independent data to use in the regression.
211198 prior: str, optional
212199 Name of the PyMC distribution to use for the regression coefficients. Default is "Normal".
213200 pooling: str, one of ["none", "complete", "partial"], default "complete"
214201 Type of pooling to use for the regression coefficients. If "none", no pooling is applied, and each group in
215202 the pooling_columns is treated as independent. If "complete", complete pooling is applied, and all data are
216203 treated as coming from the same group. If "partial", a hierarchical prior is constructed that shares
217204 information across groups in the pooling_columns.
218- curve_type : str, one of ["log", "abc", "ns", "nss", "box-cox"]
219- Type of curve to build. For details, see the build_curve function .
220- prior_params: dict, optional
221- Additional keyword arguments to pass to the PyMC distribution specified by the prior argument .
205+ pooling_columns : str or list of str, optional
206+ Columns of the independent data to use as labels for pooling. These columns will be treated as categorical .
207+ If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
208+ to right, with the mean of each subsequent level centered on the mean of the previous level .
222209 hierarchical_params: dict, optional
223210 Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth function.
224211 Options include:
@@ -229,34 +216,37 @@ def __init__(
229216 Default is {"alpha": 2, "beta": 1}
230217 offset_dist: str, one of ["zerosum", "normal", "laplace"]
231218 Name of the distribution to use for the offset distribution. Default is "zerosum"
219+ prior_params:
220+ Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
232221 """
233- _validate_pooling_params (index_data , pooling )
222+ _validate_pooling_params (pooling_columns , pooling )
234223
235- self .name = name
236- self .X = X
237- self .index_data = index_data
224+ self .feature_columns = feature_columns
238225 self .pooling = pooling
226+ self .pooling_columns = pooling_columns
239227
240228 self .prior = prior
241229 self .prior_params = prior_params
242230
243- super ().__init__ ()
231+ name = name if name else f"Regression({ feature_columns } )"
232+
233+ super ().__init__ (name = name )
244234
245235 def build (self , model = None ):
246236 model = pm .modelcontext (model )
247237 feature_dim = f"{ self .name } _features"
248- obs_dim = self .X .index .name
249238
250239 if feature_dim not in model .coords :
251240 model .add_coord (feature_dim , self .X .columns )
252241
253242 with model :
254- X_pt = pm .Data (f"{ self .name } _data" , self .X .values , dims = [obs_dim , feature_dim ])
243+ X = select_data_columns (get_X_data (model ), self .feature_columns )
244+
255245 if self .pooling == "complete" :
256246 beta = getattr (pm , self .prior )(
257247 f"{ self .name } " , ** self .prior_params , dims = [feature_dim ]
258248 )
259- return X_pt @ beta
249+ return X @ beta
260250
261251 beta = hierarchical_prior_to_requested_depth (
262252 self .name ,
@@ -266,19 +256,20 @@ def build(self, model=None):
266256 no_pooling = self .pooling == "none" ,
267257 )
268258
269- regression_effect = (X_pt * beta .T ).sum (axis = - 1 )
259+ regression_effect = (X * beta .T ).sum (axis = - 1 )
270260 return regression_effect
271261
272262
273263class Spline (Regression ):
274264 def __init__ (
275265 self ,
276266 name : str ,
267+ * ,
268+ feature_column : str | None = None ,
277269 n_knots : int = 10 ,
278- spline_data : pd .Series | pd .DataFrame | None = None ,
279270 prior : str = "Normal" ,
280271 index_data : pd .Series | None = None ,
281- pooling : POOLING_TYPES = "complete" ,
272+ pooling : PoolingType = "complete" ,
282273 ** prior_params ,
283274 ):
284275 """
@@ -297,10 +288,8 @@ def __init__(
297288 Name of the intercept term. If None, a default name is generated based on the index_data.
298289 n_knots: int, default 10
299290 Number of knots to use in the spline basis.
300- spline_data: Series or DataFrame
301- Exogenous data to be interpolated using basis splines. If Series, must have a name attribute. If dataframe,
302- must have exactly one column. In either case, the index of the data should match the index of the observed
303- data.
291+ feature_column: str
292+ Column of the independent data to use in the spline.
304293 index_data: Series or DataFrame, optional
305294 Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
306295 levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
@@ -330,17 +319,46 @@ def __init__(
330319 Name of the distribution to use for the offset distribution. Default is "zerosum"
331320 """
332321 _validate_pooling_params (index_data , pooling )
322+ self .name = name if name else f"Spline({ feature_column } )"
323+ self .feature_column = feature_column
324+ self .n_knots = n_knots
325+ self .prior = prior
326+ self .prior_params = prior_params
333327
334- spline_features = dmatrix (
335- f"bs(maturity_years, df={ n_knots } , degree=3) - 1" ,
336- {"maturity_years" : spline_data },
337- )
338- X = pd .DataFrame (
339- spline_features ,
340- index = spline_data .index ,
341- columns = [f"Spline_{ i } " for i in range (n_knots )],
342- )
328+ super ().__init__ (name = name )
343329
344- super ().__init__ (
345- name = name , X = X , prior = prior , index_data = index_data , pooling = pooling , ** prior_params
346- )
def build(self, model: pm.Model | None = None):
    """
    Build the spline component inside a PyMC model.

    A cubic B-spline basis (patsy ``bs`` with ``df=n_knots`` and ``degree=3``, intercept column removed) is
    computed from ``feature_column`` of the model's X data, and one coefficient is created per basis column.

    Parameters
    ----------
    model: pm.Model, optional
        Model to add the component to. If None, the model on the context stack is used.

    Returns
    -------
    spline_effect
        With complete pooling, the product ``X_spline @ beta``. Otherwise, the sum over the last axis of
        ``X_spline * beta.T``, where ``beta`` comes from ``hierarchical_prior_to_requested_depth``.

    Raises
    ------
    ValueError
        If pooling is not "complete" and no pooling columns were provided.
    """
    model = pm.modelcontext(model)

    # One dimension name is used both to register the coordinate and to label the coefficients; registering
    # under one name and referencing another leaves the coefficient dim undefined in the model.
    spline_dim = f"{self.name}_spline"
    model.add_coord(spline_dim, range(self.n_knots))

    with model:
        # get_value() pulls concrete values out of the X data container -- presumably because patsy cannot
        # operate on symbolic tensors.
        feature_values = select_data_columns(get_X_data(model).get_value(), self.feature_column)

        X_spline = dmatrix(
            f"bs({self.feature_column}, df={self.n_knots}, degree=3) - 1",
            data={self.feature_column: feature_values},
            return_type="dataframe",
        )

        if self.pooling == "complete":
            beta = getattr(pm, self.prior)(f"{self.name}", **self.prior_params, dims=spline_dim)
            return X_spline @ beta

        # Without pooling columns there is nothing to build a group-level prior from; fail with a clear
        # message instead of hitting an unbound ``beta`` further down.
        if self.pooling_columns is None:
            raise ValueError(
                f'pooling_columns must be provided when pooling is "{self.pooling}"; '
                'only pooling="complete" can be built without them.'
            )

        # NOTE(review): the argument order here (columns, model) differs from the (data, columns) order used
        # for the feature column above -- confirm against the signature of select_data_columns.
        X = select_data_columns(self.pooling_columns, model)
        beta = hierarchical_prior_to_requested_depth(
            name=self.name,
            X=X,
            model=model,
            dims=[spline_dim],
            no_pooling=self.pooling == "none",
        )

        return (X_spline * beta.T).sum(axis=-1)
0 commit comments