
Commit 5824ea5

Merge pull request #1 from strongio/feature/custom-penalties
Support for custom penalties
2 parents 5a97e06 + 13df98d commit 5824ea5

2 files changed: +120 additions, −29 deletions


foundry/glm/glm.py

Lines changed: 71 additions & 29 deletions
@@ -17,6 +17,7 @@
 from foundry.glm.family.survival import SurvivalFamily
 from foundry.glm.util import NoWeightModule, Stopping
 from foundry.hessian import hessian
+from foundry.penalty import L2
 from foundry.util import FitFailedException, is_invalid, get_to_kwargs, to_tensor, to_2d

 ModelMatrix = Union[np.ndarray, pd.DataFrame, dict]
@@ -27,11 +28,27 @@


 class Glm(BaseEstimator):
+    """
+    :param family: Either a :class:`foundry.glm.family.Family`, or a string alias. You can see the available
+     aliases with ``Glm.family_aliases()``.
+    :param penalty: A multiplier for an L2 penalty on the coefficients. Can be a single float, or a dictionary
+     of floats to support a different penalty per distribution-parameter. Instead of floats, you can also pass
+     functions that take a ``torch.nn.Module`` as the first argument and that module's param-names as the second
+     argument, and return a scalar penalty that will be applied to the log-prob.
+    :param predict_params: Many distributions have multiple parameters: for example, the normal distribution has
+     a location and a scale parameter. If a single dataframe/matrix is passed to ``fit()``, the default behavior
+     is to use it to separately predict each of loc/scale. Sometimes this is not desired: for example, we may only
+     want to use the predictors for the location, with the scale being 'intercept only'. This can be accomplished
+     with ``predict_params=['loc']`` (replacing 'loc' with the relevant name(s) for your distribution of interest).
+     For finer-grained control (e.g. using some predictors for some params and other predictors for others), see
+     the options for passing a dictionary to ``fit()``.
+    """

     def __init__(self,
                  family: Union[str, Family],
                  penalty: Union[float, Sequence[float], Dict[str, float]] = 0.,
                  predict_params: Optional[Sequence[str]] = None):
+
         self.family = family
         self.penalty = penalty
         self.predict_params = predict_params
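
For illustration, the ``penalty`` forms described in the new docstring could look like the following sketch (the 'gaussian' alias and the 'loc'/'scale' param-names are assumptions; the real names come from ``self.family.params``):

    from foundry.glm import Glm
    from foundry.penalty import L2

    Glm(family='gaussian', penalty=1.0)  # one multiplier shared by all distribution-parameters
    Glm(family='gaussian', penalty={'loc': 1.0, 'scale': 0.5})  # per-parameter multipliers
    # callables taking (module, param_names) and returning a scalar penalty:
    Glm(family='gaussian', penalty={'loc': L2(precision=1.0), 'scale': L2(precision=0.5)})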
@@ -69,19 +86,28 @@ def fit(self,
             groups: Optional[np.ndarray] = None,
             **kwargs) -> 'Glm':
         """
-        :param X:
-        :param y:
-        :param sample_weight:
-        :param groups:
-        :param reset:
-        :param callbacks:
-        :param stopping: args/kwargs to pass to :class:`foundry.glm.util.Stopping` (e.g. ``(.01,)`` would
-         use abstol of .01).
-        :param max_iter:
-        :param max_loss:
-        :param verbose:
-        :param estimate_laplace_coefs:
-        :return:
+        :param X: An array/dataframe of predictors, or a dictionary of these. If a dict, the keys should correspond
+         to ``self.family.params``.
+        :param y: An array of targets. This can instead be a dictionary, with the target-value in the 'value' entry
+         and additional auxiliary information (e.g. sample-weights, upper/lower censoring for survival modeling) in
+         other entries.
+        :param sample_weight: The weight for each row. If performing cross-validation, this argument should not be
+         used, as sklearn does not support it; instead, pass a :class:`foundry.util.SliceDict` for the ``y`` argument,
+         with a 'sample_weight' entry.
+        :param groups: todo
+        :param reset: If calling ``fit()`` more than once, whether the module/weights should be reinitialized. Default True.
+        :param callbacks: A list of callbacks: functions that take the ``Glm`` instance as the first argument and
+         the train-loss as the second argument.
+        :param stopping: Controls stopping based on converging loss/parameters. This argument is passed to
+         :class:`foundry.glm.util.Stopping` (e.g. ``(.01,)`` would use an abstol of .01).
+        :param max_iter: The max number of iterations before training stops regardless of convergence. Default 200.
+        :param max_loss: If training stops and the loss is higher than this, a :class:`foundry.util.FitFailedException`
+         will be raised and fitting will be retried with a different set of inits.
+        :param verbose: Whether to allow print statements and a progress bar during training. Default True.
+        :param estimate_laplace_coefs: If True, then after fitting, the hessian of the optimized parameters will be
+         estimated; this can then be used for confidence-intervals and statistical inference (see
+         ``coef_dataframe_``). Can be set to False if you want to save time and skip this step.
+        :return: This ``Glm`` instance.
         """
         self.family = self._init_family(self.family)
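
A hedged usage sketch of ``fit()`` as documented above (the data, the 'gaussian' alias, and the keyword-style ``SliceDict`` construction are all assumptions, not part of this commit):

    import numpy as np
    import pandas as pd
    from foundry.glm import Glm
    from foundry.util import SliceDict

    X = pd.DataFrame({'x1': np.random.randn(100), 'x2': np.random.randn(100)})
    y = 2 * X['x1'].values + np.random.randn(100)

    glm = Glm(family='gaussian').fit(X, y, max_iter=100, verbose=False)

    # packaging sample-weights into ``y`` so that sklearn cross-validation can slice them with the data:
    glm.fit(X, SliceDict(value=y, sample_weight=np.ones(len(y))))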

@@ -92,6 +118,12 @@ def fit(self,
             warn("`groups` argument will be ignored because self.penalty is a single value not a sequence.")
         return self._fit(X=X, y=y, sample_weight=sample_weight, **kwargs)

+    @staticmethod
+    def family_aliases() -> dict:
+        out = Family.aliases.copy()
+        out.update({f'survival_{nm}': f for f, nm in SurvivalFamily.aliases.items()})
+        return out
+
     @staticmethod
     def _init_family(family: Union[Family, str]) -> Family:
         if isinstance(family, str):
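
A quick way to inspect which strings ``family`` accepts (the exact keys shown in the comment are illustrative only):

    from foundry.glm import Glm
    print(sorted(Glm.family_aliases()))  # base aliases plus their 'survival_'-prefixed variants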
@@ -231,6 +263,7 @@ def _get_xdict(self, X: ModelMatrix) -> Dict[str, torch.Tensor]:
         if isinstance(X, dict):
             Xdict = X.copy()
         else:
+            # TODO: if originally passed a dict but are now passing a dataframe, this will lead to cryptic errors later
            Xdict = {p: X for p in self.expected_model_mat_params_}

         # validate:
@@ -351,8 +384,8 @@ def module_factory(cls,
         """
         Given a model-matrix and output-dim, produce a ``torch.nn.Module`` that predicts a distribution-parameter. The
         default produces a ``torch.nn.Linear`` layer. Additionally, this function returns a dictionary whose keys are
-        param-names and whose values are the names of the individual elements. For the default case, each weight is
-        named according to the column-names in X (if X is a dataframe).
+        param-names and whose values are the names of the individual param-elements. For the default case, each weight
+        is named according to the column-names in X (or "x{i}" if X is not a dataframe).

         :param X: A dataframe or ndarray.
         :param output_dim: The number of output dimensions.
@@ -370,17 +403,18 @@ def module_factory(cls,
         columns = list(X.columns) if hasattr(X, 'columns') else [f'x{i}' for i in range(X.shape[1])]

         module_param_names = {'bias': [], 'weight': []}
-        for i in range(output_dim):
-            module_param_names['bias'].append(f'y{i}__bias')
-            module_param_names['weight'].append([f'y{i}__{c}' for c in columns])
+        if output_dim == 1:
+            # for the most common case of 1d output, param-names are just the feature-names
+            module_param_names['bias'].append('bias')
+            module_param_names['weight'].append(columns)
+        else:
+            # if multi-output, we prefix with the output idx:
+            for i in range(output_dim):
+                module_param_names['bias'].append(f'y{i}__bias')
+                module_param_names['weight'].append([f'y{i}__{c}' for c in columns])

         return module, module_param_names

-    @classmethod
-    def penalty_from_module(cls, module: torch.nn.Module, penalty_multi: float) -> torch.Tensor:
-        feature_dist = torch.distributions.Normal(loc=0, scale=1 / penalty_multi ** .5, validate_args=False)
-        return -feature_dist.log_prob(module.weight).sum()
-
     def _init_optimizer(self) -> torch.optim.Optimizer:
         return torch.optim.LBFGS(
             self.module_.parameters(), max_iter=10, line_search_fn='strong_wolfe', lr=.25
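
To make the new naming scheme concrete, here is a sketch of what the returned ``module_param_names`` would contain for a dataframe with columns ['x1', 'x2'] (assuming the rest of the method is unchanged):

    # output_dim == 1: param-names are just the feature-names
    {'bias': ['bias'], 'weight': [['x1', 'x2']]}

    # output_dim == 2: names are prefixed with the output idx
    {'bias': ['y0__bias', 'y1__bias'],
     'weight': [['y0__x1', 'y0__x2'], ['y1__x1', 'y1__x2']]}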
@@ -433,19 +467,27 @@ def _get_penalty(self) -> torch.Tensor:
         if not self.penalty:
             return torch.zeros(1, **get_to_kwargs(self.module_))

+        # standardize to a dictionary keyed by param-names:
         if isinstance(self.penalty, dict):
-            penalty_multis = self.penalty
             if set(self.penalty) != set(self.module_.keys()):
                 raise ValueError(
                     f"``self.penalty.keys()`` is {set(self.penalty)}, but expected {set(self.module_.keys())}"
                 )
         else:
-            penalty_multis = {k: self.penalty for k in self.module_.keys()}
-
-        to_sum = [torch.tensor(0., **get_to_kwargs(self.module_))]
-        for dp, module in self.module_.items():
+            self.penalty = {k: self.penalty for k in self.module_.keys()}
+
+        # standardize values to callables:
+        for param_name in list(self.penalty):
+            maybe_callable = self.penalty[param_name]
+            if not callable(maybe_callable):
+                # if not callable, it's a multiplier -- i.e. L2's 'precision'
+                self.penalty[param_name] = L2(precision=maybe_callable)
+
+        # call each:
+        to_sum = []
+        for param_name, penalty_fun in self.penalty.items():
             to_sum.append(
-                self.penalty_from_module(module, penalty_multi=penalty_multis[dp])
+                penalty_fun(self.module_[param_name], self._module_param_names_[param_name])
             )
         return torch.stack(to_sum).sum()
foundry/penalty.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from typing import Union, Dict
+
+import numpy as np
+import torch
+
+from foundry.util import get_to_kwargs
+
+
+class Penalty:
+    def __call__(self, module: torch.nn.Module, module_param_names: Dict[str, np.ndarray]) -> torch.Tensor:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}()"
+
+
+class L2(Penalty):
+    """
+    Create a penalty that can be passed to ``Glm(penalty=)``.
+
+    :param precision: Module weights will have penalties equivalent to gaussian priors; this is the precision of
+     that gaussian distribution. Can also be a dictionary with keys corresponding to feature-names.
+    :param mean: See above; the mean of this gaussian (default zero). Can also be a dictionary with feature-names.
+    """
+
+    def __init__(self, precision: Union[float, dict], mean: Union[float, dict] = 0.):
+        if not isinstance(mean, dict):
+            mean = {'_default': mean}
+        self.mean = mean
+        if not isinstance(precision, dict):
+            precision = {'_default': precision}
+        self.precision = precision
+
+    def __call__(self, module: torch.nn.Module, module_param_names: Dict[str, np.ndarray]) -> torch.Tensor:
+        if set(module_param_names) != {'bias', 'weight'}:
+            raise NotImplementedError(f"{type(self)} not implemented for module with params!={'bias', 'weight'}")
+
+        to = get_to_kwargs(module)
+
+        feature_nms = list(module_param_names['weight'].reshape(-1))
+        try:
+            means = torch.tensor([self.mean.get(nm, self.mean['_default']) for nm in feature_nms], **to)
+            precisions = torch.tensor([self.precision.get(nm, self.precision['_default']) for nm in feature_nms], **to)
+        except KeyError as e:
+            raise RuntimeError(
+                f"mean/precision passed to {type(self)} should be a dict with keys '_default' or:\n{feature_nms}"
+            ) from e
+        feature_dist = torch.distributions.Normal(loc=means, scale=1 / precisions ** .5, validate_args=False)
+        return -feature_dist.log_prob(module.weight.view(-1)).sum()
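
For illustration, the per-feature dictionary form might be used like this (the feature name 'x2' is hypothetical and must match the module's param-names):

    from foundry.penalty import L2

    # precision of 1.0 for all weights by default, but a much stronger (more regularizing) prior on 'x2':
    pen = L2(precision={'_default': 1.0, 'x2': 100.0}, mean=0.)
    # then: Glm(family=..., penalty=pen)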
