44from numba import njit
55from scipy .sparse import issparse
66from skglm .solvers .base import BaseSolver
7+ from skglm .solvers .common import dist_fix_point_cd
78
89from sklearn .exceptions import ConvergenceWarning
910from skglm .utils .sparse_ops import _sparse_xj_dot
@@ -28,6 +29,9 @@ class ProxNewton(BaseSolver):
2829 tol : float, default 1e-4
2930 Tolerance for convergence.
3031
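32+ ws_strategy : {'subdiff', 'fixpoint'}, default 'subdiff'
33+ Score used to build the working set and to check convergence: 'subdiff' uses ``penalty.subdiff_distance``, 'fixpoint' uses ``dist_fix_point_cd``.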
34+
3135 fit_intercept : bool, default True
3236 If ``True``, fits an unpenalized intercept.
3337
@@ -49,11 +53,13 @@ class ProxNewton(BaseSolver):
4953 """
5054
5155 def __init__ (self , p0 = 10 , max_iter = 20 , max_pn_iter = 1000 , tol = 1e-4 ,
52- fit_intercept = True , warm_start = False , verbose = 0 ):
56+ ws_strategy = "subdiff" , fit_intercept = True , warm_start = False ,
57+ verbose = 0 ):
5358 self .p0 = p0
5459 self .max_iter = max_iter
5560 self .max_pn_iter = max_pn_iter
5661 self .tol = tol
62+ self .ws_strategy = ws_strategy
5763 self .fit_intercept = fit_intercept
5864 self .warm_start = warm_start
5965 self .verbose = verbose
@@ -73,6 +79,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
7379 if is_sparse :
7480 X_bundles = (X .data , X .indptr , X .indices )
7581
82+ if self .ws_strategy == "fixpoint" :
83+ X_square = X .multiply (X ) if is_sparse else X ** 2
84+
7685 if len (w ) != n_features + self .fit_intercept :
7786 if self .fit_intercept :
7887 val_error_message = (
@@ -92,7 +101,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
92101 else :
93102 grad = _construct_grad (X , y , w [:n_features ], Xw , datafit , all_features )
94103
95- opt = penalty .subdiff_distance (w [:n_features ], grad , all_features )
104+ if self .ws_strategy == "subdiff" :
105+ opt = penalty .subdiff_distance (w [:n_features ], grad , all_features )
106+ elif self .ws_strategy == "fixpoint" :
107+ lipschitz = datafit .raw_hessian (y , Xw ) @ X_square
108+ opt = dist_fix_point_cd (
109+ w [:n_features ], grad , lipschitz , datafit , penalty , all_features
110+ )
96111
97112 # optimality of intercept
98113 if fit_intercept :
@@ -128,13 +143,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
128143 for pn_iter in range (self .max_pn_iter ):
129144 # find descent direction
130145 if is_sparse :
131- delta_w_ws , X_delta_w_ws = _descent_direction_s (
146+ delta_w_ws , X_delta_w_ws , lipschitz_ws = _descent_direction_s (
132147 * X_bundles , y , w , Xw , fit_intercept , grad_ws , datafit ,
133- penalty , ws , tol = EPS_TOL * tol_in )
148+ penalty , ws , tol = EPS_TOL * tol_in , ws_strategy = self . ws_strategy )
134149 else :
135- delta_w_ws , X_delta_w_ws = _descent_direction (
150+ delta_w_ws , X_delta_w_ws , lipschitz_ws = _descent_direction (
136151 X , y , w , Xw , fit_intercept , grad_ws , datafit ,
137- penalty , ws , tol = EPS_TOL * tol_in )
152+ penalty , ws , tol = EPS_TOL * tol_in , ws_strategy = self . ws_strategy )
138153
139154 # backtracking line search with inplace update of w, Xw
140155 if is_sparse :
@@ -147,7 +162,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
147162 delta_w_ws , X_delta_w_ws , ws )
148163
149164 # check convergence
150- opt_in = penalty .subdiff_distance (w , grad_ws , ws )
165+ if self .ws_strategy == "subdiff" :
166+ opt_in = penalty .subdiff_distance (w , grad_ws , ws )
167+ elif self .ws_strategy == "fixpoint" :
168+ opt_in = dist_fix_point_cd (
169+ w , grad_ws , lipschitz_ws , datafit , penalty , ws
170+ )
151171 stop_crit_in = np .max (opt_in )
152172
153173 if max (self .verbose - 1 , 0 ):
@@ -176,7 +196,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
176196
177197@njit
178198def _descent_direction (X , y , w_epoch , Xw_epoch , fit_intercept , grad_ws , datafit ,
179- penalty , ws , tol ):
199+ penalty , ws , tol , ws_strategy ):
180200 # Given:
181201 # 1) b = \nabla F(X w_epoch)
182202 # 2) D = \nabla^2 F(X w_epoch) <------> raw_hess
@@ -229,7 +249,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
229249 # TODO: can be improved by passing in w_ws but breaks for WeightedL1
230250 current_w = w_epoch .copy ()
231251 current_w [ws_intercept ] = w_ws
232- opt = penalty .subdiff_distance (current_w , past_grads , ws )
252+ if ws_strategy == "subdiff" :
253+ opt = penalty .subdiff_distance (current_w , past_grads , ws )
254+ elif ws_strategy == "fixpoint" :
255+ opt = dist_fix_point_cd (
256+ current_w , past_grads , lipschitz , datafit , penalty , ws
257+ )
233258 stop_crit = np .max (opt )
234259
235260 if fit_intercept :
@@ -239,13 +264,14 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
239264 break
240265
241266 # descent direction
242- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws
267+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
243268
244269
245270# sparse version of _descent_direction
246271@njit
247272def _descent_direction_s (X_data , X_indptr , X_indices , y , w_epoch ,
248- Xw_epoch , fit_intercept , grad_ws , datafit , penalty , ws , tol ):
273+ Xw_epoch , fit_intercept , grad_ws , datafit , penalty , ws , tol ,
274+ ws_strategy ):
249275 dtype = X_data .dtype
250276 raw_hess = datafit .raw_hessian (y , Xw_epoch )
251277
@@ -298,7 +324,12 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
298324 # TODO: could be improved by passing in w_ws
299325 current_w = w_epoch .copy ()
300326 current_w [ws_intercept ] = w_ws
301- opt = penalty .subdiff_distance (current_w , past_grads , ws )
327+ if ws_strategy == "subdiff" :
328+ opt = penalty .subdiff_distance (current_w , past_grads , ws )
329+ elif ws_strategy == "fixpoint" :
330+ opt = dist_fix_point_cd (
331+ current_w , past_grads , lipschitz , datafit , penalty , ws
332+ )
302333 stop_crit = np .max (opt )
303334
304335 if fit_intercept :
@@ -308,7 +339,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
308339 break
309340
310341 # descent direction
311- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws
342+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
312343
313344
314345@njit
0 commit comments