Skip to content

Commit 700c280

Browse files
ENH add weights for Group Lasso (#223)
Co-authored-by: mathurinm <[email protected]>
1 parent d4d7b77 commit 700c280

File tree

7 files changed

+129
-55
lines changed

7 files changed

+129
-55
lines changed

.gitignore

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
*.o
55
*.c
66
*.html
7-
celer.egg-info/*
8-
.eggs/*
7+
celer.egg-info
8+
.eggs
99

1010
# Python precompilation
1111
*pyc
12+
*pyd
13+
14+
# build
15+
build
1216

1317

1418
# cache
15-
.pytest_cache/*
19+
.pytest_cache
20+
__pycache__
1621

1722
doc/*
1823
coverage/*

celer/dropin_sklearn.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -711,9 +711,9 @@ class GroupLasso(Lasso_sklearn):
711711
712712
The optimization objective for the Group Lasso is::
713713
714-
(1 / (2 * n_samples)) * ||y - X w||^2_2 + alpha * \sum_g ||w_g||_2
714+
(1 / (2 * n_samples)) * ||y - X w||^2_2 + alpha * \sum_g weights_g ||w_g||_2
715715
716-
where `w_g` is the weight vector of group number `g`.
716+
where `w_g` are the regression coefficients of group number `g`.
717717
718718
Parameters
719719
----------
@@ -752,6 +752,10 @@ class GroupLasso(Lasso_sklearn):
752752
fit_intercept : bool, optional (default=True)
753753
Whether or not to fit an intercept.
754754
755+
weights : array, shape (n_groups,), optional (default=None)
756+
Strictly positive weights used in the L2 penalty part of the
757+
GroupLasso objective. If None, weights equal to 1 are used.
758+
755759
warm_start : bool, optional (default=False)
756760
When set to True, reuse the solution of the previous call to fit as
757761
initialization, otherwise, just erase the previous solution.
@@ -801,7 +805,7 @@ class GroupLasso(Lasso_sklearn):
801805

802806
def __init__(self, groups=1, alpha=1., max_iter=100,
803807
max_epochs=50000, p0=10, verbose=0, tol=1e-4, prune=True,
804-
fit_intercept=True, warm_start=False):
808+
fit_intercept=True, weights=None, warm_start=False):
805809
super(GroupLasso, self).__init__(
806810
alpha=alpha, tol=tol, max_iter=max_iter,
807811
fit_intercept=fit_intercept,
@@ -811,6 +815,7 @@ def __init__(self, groups=1, alpha=1., max_iter=100,
811815
self.max_epochs = max_epochs
812816
self.p0 = p0
813817
self.prune = prune
818+
self.weights = weights
814819

815820
def path(self, X, y, alphas, coef_init=None, return_n_iter=True,
816821
**kwargs):
@@ -820,7 +825,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True,
820825
coef_init=coef_init, max_iter=self.max_iter,
821826
return_n_iter=return_n_iter, max_epochs=self.max_epochs,
822827
p0=self.p0, verbose=self.verbose, tol=self.tol, prune=self.prune,
823-
X_scale=kwargs.get('X_scale', None),
828+
weights=self.weights, X_scale=kwargs.get('X_scale', None),
824829
X_offset=kwargs.get('X_offset', None))
825830

826831
return results

celer/group_fast.pyx

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,21 @@ cdef:
2525
@cython.cdivision(True)
2626
cpdef floating primal_grplasso(
2727
floating alpha, floating[:] R, int[::1] grp_ptr,
28-
int[::1] grp_indices, floating[:] w):
28+
int[::1] grp_indices, floating[:] w, floating[:] weights):
2929
cdef floating nrm = 0.
3030
cdef int j, k, g
3131
cdef int n_samples = R.shape[0]
3232
cdef int n_groups = grp_ptr.shape[0] - 1
3333
cdef floating p_obj = fnrm2(&n_samples, &R[0], &inc) ** 2 / (2 * n_samples)
34+
3435
for g in range(n_groups):
3536
nrm = 0.
37+
3638
for k in range(grp_ptr[g], grp_ptr[g + 1]):
3739
j = grp_indices[k]
3840
nrm += w[j] ** 2
39-
p_obj += alpha * sqrt(nrm)
41+
p_obj += alpha * sqrt(nrm) * weights[g]
42+
4043
return p_obj
4144

4245

@@ -47,7 +50,7 @@ cpdef floating dnorm_grp(
4750
bint is_sparse, floating[::1] theta, int[::1] grp_ptr,
4851
int[::1] grp_indices, floating[::1, :] X, floating[::1] X_data,
4952
int[::1] X_indices, int[::1] X_indptr, floating[::1] X_mean,
50-
int ws_size, int[:] C, bint center):
53+
floating[:] weights, int ws_size, int[:] C, bint center):
5154
"""Dual norm in the group case, i.e. L2/infty ofter groups."""
5255
cdef floating Xj_theta, tmp
5356
cdef floating scal = 0.
@@ -63,6 +66,9 @@ cpdef floating dnorm_grp(
6366

6467
if ws_size == n_groups: # max over all groups
6568
for g in range(n_groups):
69+
if weights[g] == INFINITY:
70+
continue
71+
6672
tmp = 0
6773
for k in range(grp_ptr[g], grp_ptr[g + 1]):
6874
j = grp_indices[k]
@@ -79,10 +85,13 @@ cpdef floating dnorm_grp(
7985
&inc)
8086
tmp += Xj_theta ** 2
8187

82-
scal = max(scal, sqrt(tmp))
88+
scal = max(scal, sqrt(tmp) / weights[g])
8389

8490
else: # scaling only with features in C
8591
for g_idx in range(ws_size):
92+
if weights[g] == INFINITY:
93+
continue
94+
8695
g = C[g_idx]
8796
tmp = 0
8897
for k in range(grp_ptr[g], grp_ptr[g + 1]):
@@ -100,7 +109,7 @@ cpdef floating dnorm_grp(
100109
&inc)
101110
tmp += Xj_theta ** 2
102111

103-
scal = max(scal, sqrt(tmp))
112+
scal = max(scal, sqrt(tmp) / weights[g])
104113
return scal
105114

106115

@@ -110,9 +119,9 @@ cpdef floating dnorm_grp(
110119
cdef void set_prios_grp(
111120
bint is_sparse, int pb, floating[::1] theta, floating[::1, :] X,
112121
floating[::1] X_data, int[::1] X_indices, int[::1] X_indptr,
113-
floating[::1] norms_X_grp, int[::1] grp_ptr, int[::1] grp_indices,
114-
floating[::1] prios, int[::1] screened, floating radius,
115-
int * n_screened):
122+
floating[:] weights, floating[::1] norms_X_grp, int[::1] grp_ptr,
123+
int[::1] grp_indices, floating[::1] prios, int[::1] screened,
124+
floating radius, int * n_screened):
116125
cdef int i, j, k, g, startptr, endptr
117126
cdef floating nrm_Xgtheta, Xj_theta
118127
cdef int n_groups = grp_ptr.shape[0] - 1
@@ -134,7 +143,7 @@ cdef void set_prios_grp(
134143
else:
135144
Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
136145
nrm_Xgtheta += Xj_theta ** 2
137-
nrm_Xgtheta = sqrt(nrm_Xgtheta)
146+
nrm_Xgtheta = sqrt(nrm_Xgtheta) / weights[g]
138147

139148
prios[g] = (1. - nrm_Xgtheta) / norms_X_grp[g]
140149

@@ -150,8 +159,8 @@ cpdef celer_grp(
150159
int[::1] grp_ptr, floating[::1] X_data, int[::1] X_indices,
151160
int[::1] X_indptr, floating[::1] X_mean, floating[:] y, floating alpha,
152161
floating[:] w, floating[:] R, floating[::1] theta,
153-
floating[::1] norms_X_grp, floating tol, int max_iter, int max_epochs,
154-
int gap_freq=10, floating tol_ratio_inner=0.3, int p0=100,
162+
floating[::1] norms_X_grp, floating tol, floating[:] weights, int max_iter,
163+
int max_epochs, int gap_freq=10, floating tol_ratio_inner=0.3, int p0=100,
155164
bint prune=1, bint use_accel=1,
156165
bint verbose=0):
157166

@@ -225,7 +234,7 @@ cpdef celer_grp(
225234

226235
scal = dnorm_grp(
227236
is_sparse, theta, grp_ptr, grp_indices, X, X_data, X_indices,
228-
X_indptr, X_mean, n_groups, dummy_C, center)
237+
X_indptr, X_mean, weights, n_groups, dummy_C, center)
229238

230239
if scal > 1. :
231240
tmp = 1. / scal
@@ -234,11 +243,10 @@ cpdef celer_grp(
234243
d_obj = dual(pb, n_samples, alpha, norm_y2, &theta[0], &y[0])
235244

236245
if t > 0:
237-
pass
238246
# also test dual point returned by inner solver after 1st iter:
239247
scal = dnorm_grp(
240248
is_sparse, theta_inner, grp_ptr, grp_indices, X, X_data,
241-
X_indices, X_indptr, X_mean, n_groups, dummy_C, center)
249+
X_indices, X_indptr, X_mean, weights, n_groups, dummy_C, center)
242250
if scal > 1.:
243251
tmp = 1. / scal
244252
fscal(&n_samples, &tmp, &theta_inner[0], &inc)
@@ -254,7 +262,7 @@ cpdef celer_grp(
254262
highest_d_obj = d_obj
255263
# TODO implement a best_theta
256264

257-
p_obj = primal_grplasso(alpha, R, grp_ptr, grp_indices, w)
265+
p_obj = primal_grplasso(alpha, R, grp_ptr, grp_indices, w, weights)
258266
gap = p_obj - highest_d_obj
259267
gaps[t] = gap
260268

@@ -272,8 +280,9 @@ cpdef celer_grp(
272280
# radius = sqrt(gap / 2.) / alpha
273281

274282
set_prios_grp(
275-
is_sparse, pb, theta, X, X_data, X_indices, X_indptr, lc_groups,
276-
grp_ptr, grp_indices, prios, screened, radius, &n_screened)
283+
is_sparse, pb, theta, X, X_data, X_indices, X_indptr,
284+
weights, lc_groups, grp_ptr, grp_indices, prios, screened,
285+
radius, &n_screened)
277286

278287
if prune:
279288
nnz = 0
@@ -327,7 +336,7 @@ cpdef celer_grp(
327336

328337
scal = dnorm_grp(
329338
is_sparse, theta_inner, grp_ptr, grp_indices, X, X_data,
330-
X_indices, X_indptr, X_mean, ws_size, C, center)
339+
X_indices, X_indptr, X_mean, weights, ws_size, C, center)
331340

332341
if scal > 1. :
333342
tmp = 1. / scal
@@ -348,8 +357,8 @@ cpdef celer_grp(
348357
if epoch // gap_freq >= K:
349358
scal = dnorm_grp(
350359
is_sparse, thetacc, grp_ptr, grp_indices, X,
351-
X_data, X_indices, X_indptr, X_mean, ws_size, C,
352-
center)
360+
X_data, X_indices, X_indptr, X_mean, weights,
361+
ws_size, C, center)
353362

354363
if scal > 1.:
355364
tmp = 1. / scal
@@ -365,7 +374,8 @@ cpdef celer_grp(
365374

366375
if d_obj_in > highest_d_obj_in:
367376
highest_d_obj_in = d_obj_in
368-
p_obj_in = primal_grplasso(alpha, R, grp_ptr, grp_indices, w)
377+
378+
p_obj_in = primal_grplasso(alpha, R, grp_ptr, grp_indices, w, weights)
369379
gap_in = p_obj_in - highest_d_obj_in
370380

371381
if verbose_in:
@@ -402,7 +412,7 @@ cpdef celer_grp(
402412
norm_wg += w[j] ** 2
403413
norm_wg = sqrt(norm_wg)
404414
bst_scal = max(0.,
405-
1. - alpha / lc_groups[g] * n_samples / norm_wg)
415+
1. - alpha * weights[g] / lc_groups[g] * n_samples / norm_wg)
406416

407417
for k in range(grp_ptr[g + 1] - grp_ptr[g]):
408418
j = grp_indices[grp_ptr[g] + k]

celer/homotopy.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
146146

147147
if pb.lower() not in ("lasso", "logreg", "grouplasso"):
148148
raise ValueError("Unsupported problem %s" % pb)
149+
150+
n_groups = None # set n_groups to None for lasso and logreg
149151
if pb.lower() == "lasso":
150152
pb = LASSO
151153
elif pb.lower() == "logreg":
@@ -182,15 +184,7 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
182184

183185
X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X)
184186

185-
if weights is None:
186-
weights = np.ones(n_features).astype(X.dtype)
187-
elif (weights <= 0).any():
188-
raise ValueError("0 or negative weights are not supported.")
189-
elif weights.shape[0] != X.shape[1]:
190-
raise ValueError(
191-
"As many weights as features must be passed. "
192-
f"Expected {X.shape[1]}, got {weights.shape[0]}."
193-
)
187+
weights = _check_weights(weights, pb, X, n_groups)
194188

195189
if alphas is None:
196190
if pb == LASSO:
@@ -204,7 +198,7 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
204198
alpha_max = 0
205199
for g in range(n_groups):
206200
X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
207-
alpha_max = max(alpha_max, norm(X_g.T @ y, ord=2))
201+
alpha_max = max(alpha_max, norm(X_g.T @ y / weights[g], ord=2))
208202
alpha_max /= n_samples
209203

210204
alphas = alpha_max * np.geomspace(1, eps, n_alphas,
@@ -287,17 +281,17 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
287281
scal = dnorm_grp(
288282
is_sparse, theta, grp_ptr, grp_indices, X_dense,
289283
X_data, X_indices, X_indptr, X_sparse_scaling,
290-
len(grp_ptr) - 1, np.zeros(1, dtype=np.int32),
284+
weights, len(grp_ptr) - 1, np.zeros(1, dtype=np.int32),
291285
X_sparse_scaling.any())
292286
theta /= scal
293287

294288
# celer modifies w, Xw, and theta in place:
295-
if pb == GRPLASSO: # TODO this if else scheme is complicated
289+
if pb == GRPLASSO:
290+
# TODO this if else scheme is complicated
296291
sol = celer_grp(
297292
is_sparse, LASSO, X_dense, grp_indices, grp_ptr, X_data,
298-
X_indices,
299-
X_indptr, X_sparse_scaling, y, alpha, w, Xw, theta,
300-
norms_X_grp, tol, max_iter, max_epochs, p0=p0,
293+
X_indices, X_indptr, X_sparse_scaling, y, alpha, w, Xw, theta,
294+
norms_X_grp, tol, weights, max_iter, max_epochs, p0=p0,
301295
prune=prune, verbose=verbose)
302296
elif pb == LASSO or (pb == LOGREG and not use_PN):
303297
sol = celer(
@@ -325,6 +319,26 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
325319
return results
326320

327321

322+
def _check_weights(weights, pb, X, n_groups):
323+
"""Handle weights cases."""
324+
if weights is None:
325+
n_weights = n_groups if pb == GRPLASSO else X.shape[1]
326+
weights = np.ones(n_weights, dtype=X.dtype)
327+
elif (weights <= 0).any():
328+
raise ValueError("0 or negative weights are not supported.")
329+
else:
330+
expected_n_weights = n_groups if pb == GRPLASSO else X.shape[1]
331+
feat_or_grp = "groups" if pb == GRPLASSO else "features"
332+
333+
if weights.shape[0] != expected_n_weights:
334+
raise ValueError(
335+
f"As many weights as {feat_or_grp} must be passed. "
336+
f"Expected {expected_n_weights}, got {weights.shape[0]}."
337+
)
338+
339+
return weights
340+
341+
328342
def _sparse_and_dense(X):
329343
if sparse.issparse(X):
330344
X_dense = np.empty([1, 1], order='F', dtype=X.data.dtype)

celer/tests/test_docstring_parameters.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pkgutil import walk_packages
99
from inspect import getsource
1010

11+
from numpydoc import docscrape
1112
import celer
1213

1314
# copied from sklearn.fixes
@@ -65,7 +66,6 @@ def get_name(func):
6566

6667
def check_parameters_match(func, doc=None):
6768
"""Check docstring, return list of incorrect results."""
68-
from numpydoc import docscrape
6969
incorrect = []
7070
name_ = get_name(func)
7171
if not name_.startswith('celer.'):
@@ -108,8 +108,6 @@ def check_parameters_match(func, doc=None):
108108
# @requires_numpydoc
109109
def test_docstring_parameters():
110110
"""Test module docstring formatting."""
111-
from numpydoc import docscrape
112-
113111
public_modules_ = public_modules[:]
114112

115113
incorrect = []

0 commit comments

Comments (0)