1717from foundry .glm .family .survival import SurvivalFamily
1818from foundry .glm .util import NoWeightModule , Stopping
1919from foundry .hessian import hessian
20+ from foundry .penalty import L2
2021from foundry .util import FitFailedException , is_invalid , get_to_kwargs , to_tensor , to_2d
2122
# A "model matrix" argument may be a single array/dataframe, or a dict of these keyed by
# distribution-parameter name (see ``Glm.fit``).
ModelMatrix = Union[np.ndarray, pd.DataFrame, dict]
2728
2829
class Glm(BaseEstimator):
    """
    A sklearn-style GLM estimator whose distribution-parameters are predicted by torch modules.

    :param family: Either a :class:`foundry.glm.family.Family`, or a string alias. You can see available aliases with
     ``Glm.family_aliases()``.
    :param penalty: A multiplier for L2 penalty on coefficients. Can be a single float, or a dictionary of these to
     support different penalties per distribution-parameter. Instead of floats, can also pass functions that take a
     ``torch.nn.Module`` as the first argument and that module's param-names as second argument, and return a scalar
     penalty that will be applied to the log-prob.
    :param predict_params: Many distributions have multiple parameters: for example the normal distribution has a
     location and scale parameter. If a single dataframe/matrix is passed to ``fit()``, the default behavior is to
     use these to separately predict each of loc/scale. Sometimes this is not desired: for example, we only want to
     use predictors to predict the location, and the scale should be 'intercept only'. This can be accomplished with
     ``predict_params=['loc']`` (replacing 'loc' with the relevant name(s) for your distribution of interest). For
     finer-grained control (e.g. using some predictors for some params and others for others), see options about
     passing a dictionary to ``fit()``.
    """

    def __init__(self,
                 family: Union[str, Family],
                 penalty: Union[float, Sequence[float], Dict[str, float]] = 0.,
                 predict_params: Optional[Sequence[str]] = None):
        # sklearn convention: ``__init__`` only stores hyperparameters verbatim (no validation or
        # conversion), so that ``get_params()``/``set_params()``/``clone()`` behave correctly.
        # Validation/standardization happens in ``fit()``.
        self.family = family
        self.penalty = penalty
        self.predict_params = predict_params
@@ -69,19 +86,28 @@ def fit(self,
6986 groups : Optional [np .ndarray ] = None ,
7087 ** kwargs ) -> 'Glm' :
7188 """
72- :param X:
73- :param y:
74- :param sample_weight:
75- :param groups:
76- :param reset:
77- :param callbacks:
78- :param stopping: args/kwargs to pass to :class:`foundry.glm.util.Stopping` (e.g. ``(.01,)`` would
79- use abstol of .01).
80- :param max_iter:
81- :param max_loss:
82- :param verbose:
83- :param estimate_laplace_coefs:
84- :return:
89+ :param X: A array/dataframe of predictors, or a dictionary of these. If a dict, then keys should correspond to
90+ ``self.family.params``.
91+ :param y: An array of targets. This can instead be a dictionary, with the target-value in the "value" entry,
92+ and additional auxiliary information (e.g. sample-weights, upper/lower censoring for survival modeling) in
93+ other entries.
94+ :param sample_weight: The weight for each row. If performing cross-validation, this argument should not be used,
95+ as sklearn does not support it; instead, pass a :class:`foundry.util.SliceDict` for the ``y`` argument, with a
96+ 'sample_weight' entry.
        :param groups: TODO(review) — from ``fit()``, this appears to be used only when ``penalty`` is a
         sequence (otherwise a warning is emitted and it is ignored); document the intended semantics.
98+ :param reset: If calling ``fit()`` more than once, should the module/weights be reinitialized. Default True.
99+ :param callbacks: A list of callbacks: functions that take the ``Glm`` instance as a first argument and the
100+ train-loss as a second argument.
101+ :param stopping: Controls stopping based on converging loss/parameters. This argument is passed to
102+ :class:`foundry.glm.util.Stopping` (e.g. ``(.01,)`` would use abstol of .01).
103+ :param max_iter: The max. number of iterations before stopping training regardless of convergence. Default 200.
        :param max_loss: If training stops and loss is higher than this, a :class:`foundry.util.FitFailedException`
         will be raised and fitting will be retried with a different set of inits.
106+ :param verbose: Whether to allow print statements and a progress bar during training. Default True.
        :param estimate_laplace_coefs: If true, then after fitting, the hessian of the optimized parameters will be
         estimated; this can then be used for confidence-intervals and statistical inference (see
         ``coef_dataframe_``). Can set to False if you want to save time and skip this step.
110+ :return: This ``Glm`` instance.
85111 """
86112 self .family = self ._init_family (self .family )
87113
@@ -92,6 +118,12 @@ def fit(self,
92118 warn ("`groups` argument will be ignored because self.penalty is a single value not a sequence." )
93119 return self ._fit (X = X , y = y , sample_weight = sample_weight , ** kwargs )
94120
121+ @staticmethod
122+ def family_aliases () -> dict :
123+ out = Family .aliases .copy ()
124+ out .update ({f'survival_{ nm } ' : f for f , nm in SurvivalFamily .aliases .items ()})
125+ return out
126+
95127 @staticmethod
96128 def _init_family (family : Union [Family , str ]) -> Family :
97129 if isinstance (family , str ):
@@ -231,6 +263,7 @@ def _get_xdict(self, X: ModelMatrix) -> Dict[str, torch.Tensor]:
231263 if isinstance (X , dict ):
232264 Xdict = X .copy ()
233265 else :
266+ # TODO: if originally passed a dict but are now passing a dataframe, this will lead to cryptic errors later
234267 Xdict = {p : X for p in self .expected_model_mat_params_ }
235268
236269 # validate:
@@ -351,8 +384,8 @@ def module_factory(cls,
351384 """
352385 Given a model-matrix and output-dim, produce a ``torch.nn.Module`` that predicts a distribution-parameter. The
353386 default produces a ``torch.nn.Linear`` layer. Additionally, this function returns a dictionary whose keys are
354- param-names and whose values are the names of the individual elements. For the default case, each weight is
355- named according to the column-names in X (if X is a dataframe).
387+ param-names and whose values are the names of the individual param- elements. For the default case, each weight
388+ is named according to the column-names in X (or "x{i}" if X is not a dataframe).
356389
357390 :param X: A dataframe or ndarray.
358391 :param output_dim: The number of output dimensions.
@@ -370,17 +403,18 @@ def module_factory(cls,
370403 columns = list (X .columns ) if hasattr (X , 'columns' ) else [f'x{ i } ' for i in range (X .shape [1 ])]
371404
372405 module_param_names = {'bias' : [], 'weight' : []}
373- for i in range (output_dim ):
374- module_param_names ['bias' ].append (f'y{ i } __bias' )
375- module_param_names ['weight' ].append ([f'y{ i } __{ c } ' for c in columns ])
406+ if output_dim == 1 :
407+ # for most common case of 1d output, param names are just feature-names
408+ module_param_names ['bias' ].append ('bias' )
409+ module_param_names ['weight' ].append (columns )
410+ else :
411+ # if multi-output, we prefix with output idx:
412+ for i in range (output_dim ):
413+ module_param_names ['bias' ].append (f'y{ i } __bias' )
414+ module_param_names ['weight' ].append ([f'y{ i } __{ c } ' for c in columns ])
376415
377416 return module , module_param_names
378417
379- @classmethod
380- def penalty_from_module (cls , module : torch .nn .Module , penalty_multi : float ) -> torch .Tensor :
381- feature_dist = torch .distributions .Normal (loc = 0 , scale = 1 / penalty_multi ** .5 , validate_args = False )
382- return - feature_dist .log_prob (module .weight ).sum ()
383-
    def _init_optimizer(self) -> torch.optim.Optimizer:
        # L-BFGS with strong-Wolfe line search. NOTE: ``max_iter`` here is the number of *inner*
        # iterations per ``optimizer.step()`` call — distinct from the ``max_iter`` of the outer
        # training loop in ``fit()``.
        return torch.optim.LBFGS(
            self.module_.parameters(), max_iter=10, line_search_fn='strong_wolfe', lr=.25
        )
@@ -433,19 +467,27 @@ def _get_penalty(self) -> torch.Tensor:
433467 if not self .penalty :
434468 return torch .zeros (1 , ** get_to_kwargs (self .module_ ))
435469
470+ # standardize to dictionary with param-names:
436471 if isinstance (self .penalty , dict ):
437- penalty_multis = self .penalty
438472 if set (self .penalty ) != set (self .module_ .keys ()):
439473 raise ValueError (
440474 f"``self.penalty.keys()`` is { set (self .penalty )} , but expected { set (self .module_ .keys ())} "
441475 )
442476 else :
443- penalty_multis = {k : self .penalty for k in self .module_ .keys ()}
444-
445- to_sum = [torch .tensor (0. , ** get_to_kwargs (self .module_ ))]
446- for dp , module in self .module_ .items ():
477+ self .penalty = {k : self .penalty for k in self .module_ .keys ()}
478+
479+ # standardize to values = callables:
480+ for param_name in list (self .penalty ):
481+ maybe_callable = self .penalty [param_name ]
482+ if not callable (maybe_callable ):
483+ # if not callable, it's a multiplier --i.e. L2's 'precision'
484+ self .penalty [param_name ] = L2 (precision = maybe_callable )
485+
486+ # call each:
487+ to_sum = []
488+ for param_name , penalty_fun in self .penalty .items ():
447489 to_sum .append (
448- self .penalty_from_module ( module , penalty_multi = penalty_multis [ dp ])
490+ penalty_fun ( self .module_ [ param_name ], self . _module_param_names_ [ param_name ])
449491 )
450492 return torch .stack (to_sum ).sum ()
451493
0 commit comments