Skip to content

Commit 57c6091

Browse files
authored
API - Expose get_lipschitz in QuadraticGroup (scikit-learn-contrib#200)
1 parent 59f91f7 commit 57c6091

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

skglm/datafits/group.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ class QuadraticGroup(BaseDatafit):
2222
grp_ptr : array, shape (n_groups + 1,)
2323
The group pointers such that two consecutive elements delimit
2424
the indices of a group in ``grp_indices``.
25-
26-
lipschitz : array, shape (n_groups,)
27-
The lipschitz constants for each group.
2825
"""
2926

3027
def __init__(self, grp_ptr, grp_indices):
@@ -34,15 +31,14 @@ def get_spec(self):
3431
spec = (
3532
('grp_ptr', int32[:]),
3633
('grp_indices', int32[:]),
37-
('lipschitz', float64[:])
3834
)
3935
return spec
4036

4137
def params_to_dict(self):
4238
return dict(grp_ptr=self.grp_ptr,
4339
grp_indices=self.grp_indices)
4440

45-
def initialize(self, X, y):
41+
def get_lipschitz(self, X, y):
4642
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
4743
n_groups = len(grp_ptr) - 1
4844

@@ -52,7 +48,7 @@ def initialize(self, X, y):
5248
X_g = X[:, grp_g_indices]
5349
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
5450

55-
self.lipschitz = lipschitz
51+
return lipschitz
5652

5753
def value(self, y, w, Xw):
5854
return norm(y - Xw) ** 2 / (2 * len(y))

skglm/solvers/group_bcd.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
6767
raise ValueError(val_error_message)
6868

6969
datafit.initialize(X, y)
70+
lipschitz = datafit.get_lipschitz(X, y)
71+
7072
all_groups = np.arange(n_groups)
7173
p_objs_out = np.zeros(self.max_iter)
7274
stop_crit = 0. # prevent ref before assign when max_iter == 0
@@ -100,7 +102,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
100102

101103
for epoch in range(self.max_epochs):
102104
# inplace update of w and Xw
103-
_bcd_epoch(X, y, w[:n_features], Xw, datafit, penalty, ws)
105+
_bcd_epoch(X, y, w[:n_features], Xw, lipschitz, datafit, penalty, ws)
104106

105107
# update intercept
106108
if self.fit_intercept:
@@ -140,15 +142,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
140142

141143

142144
@njit
143-
def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
145+
def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws):
144146
# perform a single BCD epoch on groups in ws
145147
grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
146148

147149
for g in ws:
148150
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
149151
old_w_g = w[grp_g_indices].copy()
150152

151-
lipschitz_g = datafit.lipschitz[g]
153+
lipschitz_g = lipschitz[g]
152154
grad_g = datafit.gradient_g(X, y, w, Xw, g)
153155

154156
w[grp_g_indices] = penalty.prox_1group(

0 commit comments

Comments
 (0)