11from abc import ABC , abstractmethod
22
3- import pandas as pd
3+ import numpy as np
44import pymc as pm
5+ import pytensor .tensor as pt
56
67from model .modular .utilities import (
78 PRIOR_DEFAULT_KWARGS ,
1314 select_data_columns ,
1415)
1516from patsy import dmatrix
17+ from pytensor .graph import Apply , Op
1618
1719
1820class GLMModel (ABC ):
@@ -113,7 +115,7 @@ def __init__(
113115
114116 self .prior = prior
115117 self .prior_params = prior_params if prior_params is not None else {}
116- self .pooling_columns = at_least_list ( pooling_columns )
118+ self .pooling_columns = pooling_columns
117119
118120 name = name or f"Intercept(pooling_cols={ pooling_columns } )"
119121
@@ -193,9 +195,9 @@ def __init__(
193195 prior_params:
194196 Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
195197 """
196- self .feature_columns = at_least_list ( feature_columns )
198+ self .feature_columns = feature_columns
197199 self .pooling = pooling
198- self .pooling_columns = at_least_list ( pooling_columns )
200+ self .pooling_columns = pooling_columns
199201
200202 self .prior = prior
201203 self .prior_params = {} if prior_params is None else prior_params
@@ -210,7 +212,7 @@ def build(self, model=None):
210212 feature_dim = f"{ self .name } _features"
211213
212214 if feature_dim not in model .coords :
213- model .add_coord (feature_dim , self .feature_columns )
215+ model .add_coord (feature_dim , at_least_list ( self .feature_columns ) )
214216
215217 with model :
216218 full_X = get_X_data (model )
@@ -237,17 +239,55 @@ def build(self, model=None):
237239 return regression_effect
238240
239241
240- class Spline (Regression ):
242+ class SplineTensor (Op ):
243+ def __init__ (self , name , df = 10 , degree = 3 ):
244+ """
245+ Thin wrapper around patsy dmatrix, allowing for the creation of spline basis functions given a symbolic input.
246+
247+ Parameters
248+ ----------
249+ name: str, optional
250+ Name of the spline basis function.
251+ df: int
252+ Number of basis functions to generate
253+ degree: int
254+ Degree of the spline basis
255+ """
256+ self .name = name if name else ""
257+ self .df = df
258+ self .degree = degree
259+
260+ def make_node (self , x ):
261+ inputs = [pt .as_tensor (x )]
262+ outputs = [pt .dmatrix (f"{ self .name } _spline_basis" )]
263+
264+ return Apply (self , inputs , outputs )
265+
266+ def perform (self , node : Apply , inputs : list [np .ndarray ], outputs : list [list [None ]]) -> None :
267+ [x ] = inputs
268+
269+ outputs [0 ][0 ] = np .asarray (
270+ dmatrix (f"bs({ self .name } , df={ self .df } , degree={ self .degree } ) - 1" , data = {self .name : x })
271+ )
272+
273+
274+ def pt_spline (x , name = None , df = 10 , degree = 3 ) -> pt .Variable :
275+ return SplineTensor (name = name , df = df , degree = degree )(x )
276+
277+
278+ class Spline (GLMModel ):
241279 def __init__ (
242280 self ,
243281 name : str ,
244282 * ,
245283 feature_column : str | None = None ,
246284 n_knots : int = 10 ,
247- prior : str = "Normal" ,
248- index_data : pd .Series | None = None ,
285+ spline_degree : int = 3 ,
249286 pooling : PoolingType = "complete" ,
250- ** prior_params ,
287+ pooling_columns : ColumnType | None = None ,
288+ prior : str = "Normal" ,
289+ prior_params : dict | None = None ,
290+ hierarchical_params : dict | None = None ,
251291 ):
252292 """
253293 Class to represent a spline component in a GLM model.
@@ -263,25 +303,23 @@ def __init__(
263303 ----------
264304 name: str, optional
265305 Name of the intercept term. If None, a default name is generated based on the index_data.
266- n_knots: int, default 10
267- Number of knots to use in the spline basis.
268306 feature_column: str
269307 Column of the independent data to use in the spline.
270- index_data: Series or DataFrame, optional
271- Index data used to build hierarchical priors. If there are multiple columns, the columns are treated as
272- levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the hierarchy,
273- and depth increasing to the right.
274-
275- The index of the index_data must match the index of the observed data.
276- prior: str, optional
277- Name of the PyMC distribution to use for the intercept term. Default is "Normal".
308+ n_knots: int, default 10
309+ Number of knots to use in the spline basis.
310+ spline_degree: int, default 3
311+ Degree of the spline basis.
278312 pooling: str, one of ["none", "complete", "partial"], default "complete"
279313 Type of pooling to use for the intercept term. If "none", no pooling is applied, and each group in the
280314 index_data is treated as independent. If "complete", complete pooling is applied, and all data are treated
281315 as coming from the same group. If "partial", a hierarchical prior is constructed that shares information
282316 across groups in the index_data.
283- curve_type: str, one of ["log", "abc", "ns", "nss", "box-cox"]
284- Type of curve to build. For details, see the build_curve function.
317+ pooling_columns: str or list of str, optional
318+ Columns of the independent data to use as labels for pooling. These columns will be treated as categorical.
319+ If None, no pooling is applied. If a list is provided, a "telescoping" hierarchy is constructed from left
320+ to right, with the mean of each subsequent level centered on the mean of the previous level.
321+ prior: str, optional
322+ Name of the PyMC distribution to use for the intercept term. Default is "Normal".
285323 prior_params: dict, optional
286324 Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
287325 hierarchical_params: dict, optional
@@ -295,45 +333,49 @@ def __init__(
295333 offset_dist: str, one of ["zerosum", "normal", "laplace"]
296334 Name of the distribution to use for the offset distribution. Default is "zerosum"
297335 """
298- self .name = name if name else f"Spline({ feature_column } )"
299336 self .feature_column = feature_column
300337 self .n_knots = n_knots
338+ self .spline_degree = spline_degree
339+
301340 self .prior = prior
302- self .prior_params = prior_params
341+ self .prior_params = {} if prior_params is None else prior_params
342+ self .hierarchical_params = {} if hierarchical_params is None else hierarchical_params
303343
344+ self .pooling = pooling
345+ self .pooling_columns = pooling_columns
346+
347+ name = name if name else f"Spline({ feature_column } , df={ n_knots } , degree={ spline_degree } )"
304348 super ().__init__ (name = name )
305349
306350 def build (self , model : pm .Model | None = None ):
307351 model = pm .modelcontext (model )
308- model .add_coord (f"{ self .name } _spline" , range (self .n_knots ))
352+ spline_dim = f"{ self .name } _knots"
353+ model .add_coord (spline_dim , range (self .n_knots ))
309354
310355 with model :
311- spline_data = {
312- self .feature_column : select_data_columns (
313- get_X_data (model ).get_value (), self .feature_column
314- )
315- }
316-
317- X_spline = dmatrix (
318- f"bs({ self .feature_column } , df={ self .n_knots } , degree=3) - 1" ,
319- data = spline_data ,
320- return_type = "dataframe" ,
356+ X_spline = pt_spline (
357+ select_data_columns (self .feature_column , model ),
358+ name = self .feature_column ,
359+ df = self .n_knots ,
360+ degree = self .spline_degree ,
321361 )
322362
323363 if self .pooling == "complete" :
324- beta = getattr (pm , self .prior )(
325- f"{ self .name } " , ** self .prior_params , dims = f"{ self .feature_column } _spline"
326- )
364+ prior_params = PRIOR_DEFAULT_KWARGS [self .prior ].copy ()
365+ prior_params .update (self .prior_params )
366+
367+ beta = getattr (pm , self .prior )(f"{ self .name } " , ** prior_params , dims = [spline_dim ])
327368 return X_spline @ beta
328369
329370 elif self .pooling_columns is not None :
330- X = select_data_columns (self .pooling_columns , model )
331371 beta = make_hierarchical_prior (
332372 name = self .name ,
333- X = X ,
373+ X = get_X_data (model ),
374+ pooling = self .pooling ,
375+ pooling_columns = self .pooling_columns ,
334376 model = model ,
335- dims = [f" { self . feature_column } _spline" ],
336- no_pooling = self .pooling == "none" ,
377+ dims = [spline_dim ],
378+ ** self .hierarchical_params ,
337379 )
338380
339381 spline_effect = (X_spline * beta .T ).sum (axis = - 1 )
0 commit comments