Skip to content

Commit 59f91f7

Browse files
authored
MNT - Refactor dist_fix_point function (scikit-learn-contrib#194)
1 parent 920cd41 commit 59f91f7

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

skglm/solvers/anderson_cd.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from numba import njit
33
from scipy import sparse
44
from sklearn.utils import check_array
5-
from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point
5+
from skglm.solvers.common import (
6+
construct_grad, construct_grad_sparse, dist_fix_point_cd
7+
)
68
from skglm.solvers.base import BaseSolver
79
from skglm.utils.anderson import AndersonAcceleration
810

@@ -104,7 +106,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
104106
if self.ws_strategy == "subdiff":
105107
opt = penalty.subdiff_distance(w[:n_features], grad, all_feats)
106108
elif self.ws_strategy == "fixpoint":
107-
opt = dist_fix_point(
109+
opt = dist_fix_point_cd(
108110
w[:n_features], grad, lipschitz, datafit, penalty, all_feats
109111
)
110112

@@ -181,7 +183,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
181183
if self.ws_strategy == "subdiff":
182184
opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws)
183185
elif self.ws_strategy == "fixpoint":
184-
opt_ws = dist_fix_point(
186+
opt_ws = dist_fix_point_cd(
185187
w[:n_features], grad_ws, lipschitz, datafit, penalty, ws
186188
)
187189

skglm/solvers/common.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
@njit
6-
def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws):
6+
def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
77
"""Compute the violation of the fixed point iterate scheme.
88
99
Parameters
@@ -28,16 +28,20 @@ def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws):
2828
2929
Returns
3030
-------
31-
dist_fix_point : array, shape (n_features,)
31+
dist : array, shape (n_features,)
3232
Violation score for every feature.
3333
"""
34-
dist_fix_point = np.zeros(ws.shape[0])
34+
dist = np.zeros(ws.shape[0])
35+
3536
for idx, j in enumerate(ws):
36-
lcj = lipschitz[j]
37-
if lcj != 0:
38-
dist_fix_point[idx] = np.abs(
39-
w[j] - penalty.prox_1d(w[j] - grad_ws[idx] / lcj, 1. / lcj, j))
40-
return dist_fix_point
37+
if lipschitz[j] == 0.:
38+
continue
39+
40+
step_j = 1 / lipschitz[j]
41+
dist[idx] = np.abs(
42+
w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j)
43+
)
44+
return dist
4145

4246

4347
@njit

skglm/solvers/multitask_bcd.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
6666
if self.ws_strategy == "subdiff":
6767
opt = penalty.subdiff_distance(W, grad, all_feats)
6868
elif self.ws_strategy == "fixpoint":
69-
opt = dist_fix_point(W, grad, datafit, penalty, all_feats)
69+
opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats)
7070
stop_crit = np.max(opt)
7171
if self.verbose:
7272
print(f"Stopping criterion max violation: {stop_crit:.2e}")
@@ -150,7 +150,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
150150
if self.ws_strategy == "subdiff":
151151
opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
152152
elif self.ws_strategy == "fixpoint":
153-
opt_ws = dist_fix_point(
153+
opt_ws = dist_fix_point_bcd(
154154
W, grad_ws, lipschitz, datafit, penalty, ws
155155
)
156156

@@ -231,7 +231,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)
231231

232232

233233
@njit
234-
def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
234+
def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
235235
"""Compute the violation of the fixed point iterate schema.
236236
237237
Parameters
@@ -256,17 +256,20 @@ def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
256256
257257
Returns
258258
-------
259-
dist_fix_point : array, shape (ws_size,)
259+
dist : array, shape (ws_size,)
260260
Contain the violation score for every feature.
261261
"""
262-
dist_fix_point = np.zeros(ws.shape[0])
262+
dist = np.zeros(ws.shape[0])
263+
263264
for idx, j in enumerate(ws):
264-
lcj = lipschitz[j]
265-
if lcj:
266-
dist_fix_point[idx] = norm(
267-
W[j] - penalty.prox_1feat(W[j] - grad_ws[idx] / lcj, 1. / lcj, j)
268-
)
269-
return dist_fix_point
265+
if lipschitz[j] == 0.:
266+
continue
267+
268+
step_j = 1 / lipschitz[j]
269+
dist[idx] = norm(
270+
W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j)
271+
)
272+
return dist
270273

271274

272275
@njit

0 commit comments

Comments
 (0)