Commit 1b699a3

API - make BaseDatafit and BasePenalty regular classes (scikit-learn-contrib#205)

1 parent c67a885

5 files changed: +20 additions, -177 deletions

doc/changes/0.4.rst

Lines changed: 1 addition & 0 deletions

@@ -5,4 +5,5 @@ Version 0.4 (in progress)
 - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)
 - Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`)
 - Add :ref:`LogSumPenalty <skglm.penalties.LogSumPenalty>` (PR: :gh:`#127`)
+- Remove abstract methods in ``BaseDatafit`` and ``BasePenalty`` to make the solver/penalty/datafit compatibility check easier (PR: :gh:`#205`)
 - Add fixed-point distance to build working sets in :ref:`ProxNewton <skglm.solvers.ProxNewton>` solver (:gh:`138`)

doc/tutorials/add_datafit.rst

Lines changed: 9 additions & 10 deletions

@@ -30,16 +30,17 @@ They can then be passed to a :class:`~skglm.GeneralizedLinearEstimator`.
     )


-A ``Datafit`` is a jitclass which must inherit from the ``BaseDatafit`` class:
+A ``Datafit`` is a jitclass that must inherit from the ``BaseDatafit`` class:

-.. literalinclude:: ../skglm/datafits/base.py
+.. literalinclude:: ../../skglm/datafits/base.py
    :pyobject: BaseDatafit


-To define a custom datafit, you need to implement the methods declared in the ``BaseDatafit`` class.
-One needs to overload at least the ``value`` and ``gradient`` methods for skglm to support the datafit.
+To define a custom datafit, you need to inherit from the ``BaseDatafit`` class and implement the methods required by the targeted solver.
+These methods can be found in the solver documentation.
 Optionally, overloading the methods with the suffix ``_sparse`` adds support for sparse datasets (CSC matrix).
-As an example, we show how to implement the Poisson datafit in skglm.
+
+This tutorial shows how to implement the :ref:`Poisson <skglm.datafits.Poisson>` datafit to be fitted with the :ref:`ProxNewton <skglm.solvers.ProxNewton>` solver.


 A case in point: defining Poisson datafit

@@ -104,18 +105,16 @@ For the Poisson datafit, this yields
 .. math::
     \frac{\partial F(\beta)}{\partial \beta_j} = \frac{1}{n}
     \sum_{i=1}^n X_{i,j} \left(
-        \exp([X\beta]_i) - y
+        \exp([X\beta]_i) - y_i
     \right)
     \ .


 When implementing these quantities in the ``Poisson`` datafit class, this gives:

-.. literalinclude:: ../skglm/datafits/single_task.py
+.. literalinclude:: ../../skglm/datafits/single_task.py
    :pyobject: Poisson


 Note that we have not initialized any quantities in the ``initialize`` method.
-Usually it serves to compute a Lipschitz constant of the datafit, whose inverse is used by the solver as a step size.
-However, in this example, the Poisson datafit has no Lipschitz constant since the eigenvalues of the Hessian matrix are unbounded.
-This implies that a step size is not known in advance and a line search has to be performed at every epoch by the solver.
+Usually, it serves to compute datafit attributes specific to a dataset ``X, y`` for computational efficiency, for example the computation of ``X.T @ y`` in the :ref:`Quadratic <skglm.datafits.Quadratic>` datafit.
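
To make the tutorial's quantities concrete, here is a small NumPy sketch of the three Poisson quantities a ProxNewton-style solver consumes. The free-standing function names are for exposition only; the shipped implementation is the ``Poisson`` jitclass in ``skglm/datafits/single_task.py``.

import numpy as np

# Illustrative sketch only (function names hypothetical).
# F(beta) = (1 / n) * sum_i (exp([X beta]_i) - y_i * [X beta]_i),
# dropping the constant log(y_i!) term.

def poisson_value(y, Xw):
    # Datafit value at the current model fit Xw = X @ w.
    return np.sum(np.exp(Xw) - y * Xw) / len(y)

def poisson_raw_grad(y, Xw):
    # Per-sample gradient of F with respect to Xw; the solver recovers
    # the gradient in w via the chain rule: grad_j = X[:, j] @ raw_grad.
    return (np.exp(Xw) - y) / len(y)

def poisson_raw_hessian(y, Xw):
    # Diagonal of the Hessian of F with respect to Xw (diagonal because
    # the Poisson datafit is separable across samples).
    return np.exp(Xw) / len(y)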

doc/tutorials/add_penalty.rst

Lines changed: 5 additions & 3 deletions

@@ -10,10 +10,12 @@ skglm supports any arbitrary proximable penalty.

 It is implemented as a jitclass which must inherit from the ``BasePenalty`` class:

-.. literalinclude:: ../skglm/penalties/base.py
+.. literalinclude:: ../../skglm/penalties/base.py
    :pyobject: BasePenalty

-To implement your own penalty, you only need to define a new jitclass, inheriting from ``BasePenalty`` and define how its value, proximal operator, distance to subdifferential (for KKT violation) and penalized features are computed.
+To implement your own penalty, you only need to define a new jitclass, inheriting from ``BasePenalty``, and implement the methods required by the targeted solver.
+These methods can be found in the solver documentation.
+

 A case in point: defining L1 penalty
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -66,6 +68,6 @@ Note that since ``lambda`` is a reserved keyword in Python, ``alpha`` in skglm c
 When putting all together, this gives the implementation of the ``L1`` penalty:


-.. literalinclude:: ../skglm/penalties/separable.py
+.. literalinclude:: ../../skglm/penalties/separable.py
    :pyobject: L1
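
For intuition, the two core quantities the tutorial asks for reduce, in the L1 case, to the formulas below. This is an illustrative NumPy sketch with hypothetical function names, not the shipped ``L1`` jitclass from ``skglm/penalties/separable.py``.

import numpy as np

# Illustrative sketch only (function names hypothetical).

def l1_value(alpha, w):
    # Penalty value: alpha * ||w||_1.
    return alpha * np.sum(np.abs(w))

def l1_prox_1d(alpha, value, stepsize):
    # Soft-thresholding: the proximal operator of alpha * |.| applied
    # to a single coordinate `value` with the given step size.
    return np.sign(value) * max(abs(value) - alpha * stepsize, 0.0)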

skglm/datafits/base.py

Lines changed: 3 additions & 130 deletions

@@ -1,10 +1,7 @@
-from abc import abstractmethod

-
-class BaseDatafit():
+class BaseDatafit:
     """Base class for datafits."""

-    @abstractmethod
     def get_spec(self):
         """Specify the numba types of the class attributes.

@@ -14,7 +11,6 @@ def get_spec(self):
             spec to be passed to Numba jitclass to compile the class.
         """

-    @abstractmethod
     def params_to_dict(self):
         """Get the parameters to initialize an instance of the class.

@@ -24,7 +20,6 @@ def params_to_dict(self):
             The parameters to instantiate an object of the class.
         """

-    @abstractmethod
     def initialize(self, X, y):
         """Pre-computations before fitting on X and y.

@@ -37,9 +32,7 @@ def initialize(self, X, y):
             Target vector.
         """

-    @abstractmethod
-    def initialize_sparse(
-            self, X_data, X_indptr, X_indices, y):
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         """Pre-computations before fitting on X and y when X is a sparse matrix.

         Parameters
@@ -57,7 +50,6 @@ def initialize_sparse(
             Target vector.
         """

-    @abstractmethod
     def value(self, y, w, Xw):
         """Value of datafit at vector w.

@@ -78,68 +70,10 @@ def value(self, y, w, Xw):
             The datafit value at vector w.
         """

-    @abstractmethod
-    def gradient_scalar(self, X, y, w, Xw, j):
-        """Gradient with respect to j-th coordinate of w.
-
-        Parameters
-        ----------
-        X : array, shape (n_samples, n_features)
-            Design matrix.
-
-        y : array, shape (n_samples,)
-            Target vector.
-
-        w : array, shape (n_features,)
-            Coefficient vector.
-
-        Xw : array, shape (n_samples,)
-            Model fit.
-
-        j : int
-            The coordinate at which the gradient is evaluated.
-
-        Returns
-        -------
-        gradient : float
-            The gradient of the datafit with respect to the j-th coordinate of w.
-        """
-
-    @abstractmethod
-    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
-        """Gradient with respect to j-th coordinate of w when X is sparse.
-
-        Parameters
-        ----------
-        X_data : array, shape (n_elements,)
-            `data` attribute of the sparse CSC matrix X.
-
-        X_indptr : array, shape (n_features + 1,)
-            `indptr` attribute of the sparse CSC matrix X.
-
-        X_indices : array, shape (n_elements,)
-            `indices` attribute of the sparse CSC matrix X.
-
-        y : array, shape (n_samples,)
-            Target vector.
-
-        Xw: array, shape (n_samples,)
-            Model fit.
-
-        j : int
-            The dimension along which the gradient is evaluated.
-
-        Returns
-        -------
-        gradient : float
-            The gradient of the datafit with respect to the j-th coordinate of w.
-        """
-

-class BaseMultitaskDatafit():
+class BaseMultitaskDatafit:
     """Base class for multitask datafits."""

-    @abstractmethod
     def get_spec(self):
         """Specify the numba types of the class attributes.

@@ -149,7 +83,6 @@ def get_spec(self):
             spec to be passed to Numba jitclass to compile the class.
         """

-    @abstractmethod
     def params_to_dict(self):
         """Get the parameters to initialize an instance of the class.

@@ -159,7 +92,6 @@ def params_to_dict(self):
             The parameters to instantiate an object of the class.
         """

-    @abstractmethod
     def initialize(self, X, Y):
         """Store useful values before fitting on X and Y.

@@ -172,7 +104,6 @@ def initialize(self, X, Y):
             Multitask target.
         """

-    @abstractmethod
     def initialize_sparse(self, X_data, X_indptr, X_indices, Y):
         """Store useful values before fitting on X and Y, when X is sparse.

@@ -191,7 +122,6 @@ def initialize_sparse(self, X_data, X_indptr, X_indices, Y):
             Target matrix.
         """

-    @abstractmethod
     def value(self, Y, W, XW):
         """Value of datafit at matrix W.

@@ -211,60 +141,3 @@ def value(self, Y, W, XW):
         value : float
             The datafit value evaluated at matrix W.
         """
-
-    @abstractmethod
-    def gradient_j(self, X, Y, W, XW, j):
-        """Gradient with respect to j-th coordinate of W.
-
-        Parameters
-        ----------
-        X : array, shape (n_samples, n_features)
-            Design matrix.
-
-        Y : array, shape (n_samples, n_tasks)
-            Target matrix.
-
-        W : array, shape (n_features, n_tasks)
-            Coefficient matrix.
-
-        XW : array, shape (n_samples, n_tasks)
-            Model fit.
-
-        j : int
-            The coordinate along which the gradient is evaluated.
-
-        Returns
-        -------
-        gradient : array, shape (n_tasks,)
-            The gradient of the datafit with respect to the j-th coordinate of W.
-        """
-
-    @abstractmethod
-    def gradient_j_sparse(self, X_data, X_indptr, X_indices, Y, XW, j):
-        """Gradient with respect to j-th coordinate of W when X is sparse.
-
-        Parameters
-        ----------
-        X_data : array-like
-            `data` attribute of the sparse CSC matrix X.
-
-        X_indptr : array-like
-            `indptr` attribute of the sparse CSC matrix X.
-
-        X_indices : array-like
-            `indices` attribute of the sparse CSC matrix X.
-
-        Y : array, shape (n_samples, n_tasks)
-            Target matrix.
-
-        XW : array, shape (n_samples, n_tasks)
-            Model fit.
-
-        j : int
-            The coordinate along which the gradient is evaluated.
-
-        Returns
-        -------
-        gradient : array, shape (n_tasks,)
-            The gradient of the datafit with respect to the j-th coordinate of W.
-        """

skglm/penalties/base.py

Lines changed: 2 additions & 34 deletions

@@ -1,10 +1,7 @@
-from abc import abstractmethod

-
-class BasePenalty():
+class BasePenalty:
     """Base class for penalty subclasses."""

-    @abstractmethod
     def get_spec(self):
         """Specify the numba types of the class attributes.

@@ -14,7 +11,6 @@ def get_spec(self):
             spec to be passed to Numba jitclass to compile the class.
         """

-    @abstractmethod
     def params_to_dict(self):
         """Get the parameters to initialize an instance of the class.

@@ -24,39 +20,11 @@ def params_to_dict(self):
             The parameters to instantiate an object of the class.
         """

-    @abstractmethod
     def value(self, w):
         """Value of penalty at vector w."""

-    @abstractmethod
-    def prox_1d(self, value, stepsize, j):
-        """Proximal operator of penalty for feature j."""
-
-    @abstractmethod
-    def subdiff_distance(self, w, grad, ws):
-        """Distance of negative gradient to subdifferential at w for features in `ws`.
-
-        Parameters
-        ----------
-        w: array, shape (n_features,)
-            Coefficient vector.
-
-        grad: array, shape (ws.shape[0],)
-            Gradient of the datafit at w, restricted to features in `ws`.
-
-        ws: array, shape (ws_size,)
-            Indices of features in the working set.
-
-        Returns
-        -------
-        distances: array, shape (ws.shape[0],)
-            The distances to the subdifferential.
-        """
-
-    @abstractmethod
     def is_penalized(self, n_features):
         """Return a binary mask with the penalized features."""

-    @abstractmethod
     def generalized_support(self, w):
-        r"""Return a mask which is True for coefficients in the generalized support."""
+        """Return a mask which is True for coefficients in the generalized support."""
