
Commit 261fee0

ENH - Add fixed-point distance strategy to build working sets in ProxNewton (scikit-learn-contrib#138)
Co-authored-by: Badr-MOUFAD <[email protected]>
Parent: 4f1951c

File tree: 4 files changed (+53, -19 lines)

doc/changes/0.4.rst
Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,8 @@
 .. _changes_0_4:

 Version 0.4 (in progress)
---------------------------
+-------------------------
 - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)
 - Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`)
 - Add :ref:`LogSumPenalty <skglm.penalties.LogSumPenalty>` (PR: :gh:`#127`)
+- Add fixed-point distance to build working sets in :ref:`ProxNewton <skglm.solvers.ProxNewton>` solver (:gh:`138`)

skglm/solvers/common.py
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
     dist : array, shape (n_features,)
         Violation score for every feature.
     """
-    dist = np.zeros(ws.shape[0])
+    dist = np.zeros(ws.shape[0], dtype=w.dtype)

     for idx, j in enumerate(ws):
         if lipschitz[j] == 0.:
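The one-line change above makes the violation scores inherit the dtype of ``w``, so float32 problems are no longer silently promoted to float64. For context, ``dist_fix_point_cd`` measures how far each working-set coordinate is from being a fixed point of one proximal coordinate-descent step. A minimal sketch of that score, assuming an L1 penalty whose prox is soft-thresholding (the real function delegates to the ``penalty`` object):

import numpy as np

def soft_threshold(x, tau):
    # prox of tau * |.|, standing in for the generic penalty prox
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.)

def dist_fix_point_l1(w, grad, lipschitz, alpha):
    # |w_j - prox_{alpha/L_j}(w_j - grad_j / L_j)| for every feature j;
    # allocating with dtype=w.dtype is exactly the fix in the diff above
    dist = np.zeros(w.shape[0], dtype=w.dtype)
    nz = lipschitz != 0.
    step = 1. / lipschitz[nz]
    dist[nz] = np.abs(w[nz] - soft_threshold(w[nz] - step * grad[nz], alpha * step))
    return dist  # zero for coordinates that one CD step would leave unchanged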

skglm/solvers/prox_newton.py
Lines changed: 44 additions & 13 deletions
@@ -4,6 +4,7 @@
 from numba import njit
 from scipy.sparse import issparse
 from skglm.solvers.base import BaseSolver
+from skglm.solvers.common import dist_fix_point_cd

 from sklearn.exceptions import ConvergenceWarning
 from skglm.utils.sparse_ops import _sparse_xj_dot
@@ -28,6 +29,9 @@ class ProxNewton(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.

+    ws_strategy : ('subdiff'|'fixpoint'), optional
+        The score used to build the working set.
+
     fit_intercept : bool, default True
         If ``True``, fits an unpenalized intercept.

@@ -49,11 +53,13 @@ class ProxNewton(BaseSolver):
     """

     def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
-                 fit_intercept=True, warm_start=False, verbose=0):
+                 ws_strategy="subdiff", fit_intercept=True, warm_start=False,
+                 verbose=0):
         self.p0 = p0
         self.max_iter = max_iter
         self.max_pn_iter = max_pn_iter
         self.tol = tol
+        self.ws_strategy = ws_strategy
         self.fit_intercept = fit_intercept
         self.warm_start = warm_start
         self.verbose = verbose
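With the constructor change above, the new argument is threaded from user code down to the working-set construction. A hedged usage sketch (the ``GeneralizedLinearEstimator`` wiring and the exact import paths are assumptions based on skglm's public API, not shown in this diff):

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Logistic
from skglm.penalties import L1
from skglm.solvers import ProxNewton

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 100))
y = np.sign(rng.standard_normal(50))

# build working sets from the fixed-point distance instead of the
# default subdifferential distance
solver = ProxNewton(ws_strategy="fixpoint", tol=1e-8)
clf = GeneralizedLinearEstimator(
    datafit=Logistic(), penalty=L1(alpha=0.01), solver=solver,
).fit(X, y)
print(np.sum(clf.coef_ != 0), "nonzero coefficients")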
@@ -73,6 +79,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        if is_sparse:
            X_bundles = (X.data, X.indptr, X.indices)

+        if self.ws_strategy == "fixpoint":
+            X_square = X.multiply(X) if is_sparse else X ** 2
+
        if len(w) != n_features + self.fit_intercept:
            if self.fit_intercept:
                val_error_message = (
@@ -92,7 +101,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
            else:
                grad = _construct_grad(X, y, w[:n_features], Xw, datafit, all_features)

-            opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+            if self.ws_strategy == "subdiff":
+                opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+            elif self.ws_strategy == "fixpoint":
+                lipschitz = datafit.raw_hessian(y, Xw) @ X_square
+                opt = dist_fix_point_cd(
+                    w[:n_features], grad, lipschitz, datafit, penalty, all_features
+                )

            # optimality of intercept
            if fit_intercept:
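The ``X_square`` precomputation (with ``X.multiply(X)`` keeping sparse inputs sparse) exists so all per-feature curvatures can be formed with a single matrix-vector product: for a separable datafit, the second derivative along feature j is sum_i raw_hess_i * X[i, j]**2. A minimal NumPy check of that identity, assuming a logistic datafit whose raw Hessian is sigma * (1 - sigma), up to whatever normalization skglm applies:

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((12, 5))
Xw = X @ rng.standard_normal(5)

# raw (diagonal) Hessian of the logistic datafit at Xw, up to normalization
sigma = 1. / (1. + np.exp(-Xw))
raw_hess = sigma * (1. - sigma)

lipschitz = raw_hess @ X ** 2  # one product gives all n_features curvatures
expected = np.array([(raw_hess * X[:, j] ** 2).sum() for j in range(X.shape[1])])
np.testing.assert_allclose(lipschitz, expected)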
@@ -128,13 +143,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
            for pn_iter in range(self.max_pn_iter):
                # find descent direction
                if is_sparse:
-                    delta_w_ws, X_delta_w_ws = _descent_direction_s(
+                    delta_w_ws, X_delta_w_ws, lipschitz_ws = _descent_direction_s(
                        *X_bundles, y, w, Xw, fit_intercept, grad_ws, datafit,
-                        penalty, ws, tol=EPS_TOL*tol_in)
+                        penalty, ws, tol=EPS_TOL*tol_in, ws_strategy=self.ws_strategy)
                else:
-                    delta_w_ws, X_delta_w_ws = _descent_direction(
+                    delta_w_ws, X_delta_w_ws, lipschitz_ws = _descent_direction(
                        X, y, w, Xw, fit_intercept, grad_ws, datafit,
-                        penalty, ws, tol=EPS_TOL*tol_in)
+                        penalty, ws, tol=EPS_TOL*tol_in, ws_strategy=self.ws_strategy)

                # backtracking line search with inplace update of w, Xw
                if is_sparse:
@@ -147,7 +162,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                        delta_w_ws, X_delta_w_ws, ws)

                # check convergence
-                opt_in = penalty.subdiff_distance(w, grad_ws, ws)
+                if self.ws_strategy == "subdiff":
+                    opt_in = penalty.subdiff_distance(w, grad_ws, ws)
+                elif self.ws_strategy == "fixpoint":
+                    opt_in = dist_fix_point_cd(
+                        w, grad_ws, lipschitz_ws, datafit, penalty, ws
+                    )
                stop_crit_in = np.max(opt_in)

                if max(self.verbose-1, 0):
@@ -176,7 +196,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

 @njit
 def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
-                       penalty, ws, tol):
+                       penalty, ws, tol, ws_strategy):
    # Given:
    # 1) b = \nabla F(X w_epoch)
    # 2) D = \nabla^2 F(X w_epoch)  <------> raw_hess
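For orientation: with b = \nabla F(X w_epoch) and diagonal D = \nabla^2 F(X w_epoch) as in the comment above, the CD epochs in this helper minimize the quadratic model

    1/2 (X delta_w)^T D (X delta_w) + b^T X delta_w + penalty(w_epoch + delta_w),

whose per-coordinate curvatures are the ``lipschitz`` values. Returning them (see the hunks below) lets the caller reuse the same array for the fixed-point stopping criterion instead of recomputing it.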
@@ -229,7 +249,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
        # TODO: can be improved by passing in w_ws but breaks for WeightedL1
        current_w = w_epoch.copy()
        current_w[ws_intercept] = w_ws
-        opt = penalty.subdiff_distance(current_w, past_grads, ws)
+        if ws_strategy == "subdiff":
+            opt = penalty.subdiff_distance(current_w, past_grads, ws)
+        elif ws_strategy == "fixpoint":
+            opt = dist_fix_point_cd(
+                current_w, past_grads, lipschitz, datafit, penalty, ws
+            )
        stop_crit = np.max(opt)

        if fit_intercept:
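As a sanity check that the two strategies target the same optimality condition, here is a small self-contained sketch for a single coordinate under an L1 penalty; ``subdiff_dist`` and ``fixpoint_dist`` are hypothetical one-coordinate analogues of ``penalty.subdiff_distance`` and ``dist_fix_point_cd``, not skglm functions. Both scores are zero exactly on optimal coordinates, though their magnitudes differ elsewhere:

import numpy as np

alpha, L = 0.5, 2.0  # L1 strength and a coordinate-wise Lipschitz constant

def soft_threshold(x, tau):
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.)

def subdiff_dist(w_j, grad_j):
    # distance of -grad_j to alpha * sign(w_j), the L1 subdifferential
    if w_j == 0.:
        return max(0., abs(grad_j) - alpha)
    return abs(grad_j + alpha * np.sign(w_j))

def fixpoint_dist(w_j, grad_j):
    # displacement after one proximal coordinate-descent step
    return abs(w_j - soft_threshold(w_j - grad_j / L, alpha / L))

print(subdiff_dist(0., 0.3), fixpoint_dist(0., 0.3))  # 0.0 0.0  (optimal)
print(subdiff_dist(0., 0.9), fixpoint_dist(0., 0.9))  # 0.4 0.2  (violated)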
@@ -239,13 +264,14 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
            break

    # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz


 # sparse version of _descent_direction
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
-                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
+                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol,
+                         ws_strategy):
    dtype = X_data.dtype
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

@@ -298,7 +324,12 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
        # TODO: could be improved by passing in w_ws
        current_w = w_epoch.copy()
        current_w[ws_intercept] = w_ws
-        opt = penalty.subdiff_distance(current_w, past_grads, ws)
+        if ws_strategy == "subdiff":
+            opt = penalty.subdiff_distance(current_w, past_grads, ws)
+        elif ws_strategy == "fixpoint":
+            opt = dist_fix_point_cd(
+                current_w, past_grads, lipschitz, datafit, penalty, ws
+            )
        stop_crit = np.max(opt)

        if fit_intercept:
@@ -308,7 +339,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
            break

    # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz


 @njit

skglm/tests/test_prox_newton.py
Lines changed: 6 additions & 4 deletions
@@ -1,6 +1,5 @@
 import pytest
 import numpy as np
-from itertools import product
 from sklearn.linear_model import LogisticRegression

 from skglm.penalties import L1
@@ -11,8 +10,10 @@
 from skglm.utils.data import make_correlated_data


-@pytest.mark.parametrize("X_density, fit_intercept", product([1., 0.5], [True, False]))
-def test_pn_vs_sklearn(X_density, fit_intercept):
+@pytest.mark.parametrize("X_density", [1., 0.5])
+@pytest.mark.parametrize("fit_intercept", [True, False])
+@pytest.mark.parametrize("ws_strategy", ["subdiff", "fixpoint"])
+def test_pn_vs_sklearn(X_density, fit_intercept, ws_strategy):
     n_samples, n_features = 12, 25
     rho = 1e-1

@@ -30,7 +31,8 @@ def test_pn_vs_sklearn(X_density, fit_intercept):

     log_datafit = compiled_clone(Logistic())
     l1_penalty = compiled_clone(L1(alpha))
-    prox_solver = ProxNewton(fit_intercept=fit_intercept, tol=1e-12)
+    prox_solver = ProxNewton(
+        fit_intercept=fit_intercept, tol=1e-12, ws_strategy=ws_strategy)
     w = prox_solver.solve(X, y, log_datafit, l1_penalty)[0]

     np.testing.assert_allclose(w[:n_features], sk_log_reg.coef_.flatten())
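To exercise only the new strategy, pytest's ``-k`` filter can select the relevant parametrizations, for example from a Python session (the file path is as in this commit):

import pytest

# run only the fixpoint-parametrized cases of the test above
pytest.main(["skglm/tests/test_prox_newton.py", "-k", "fixpoint"])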
