@@ -66,7 +66,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
6666 if self .ws_strategy == "subdiff" :
6767 opt = penalty .subdiff_distance (W , grad , all_feats )
6868 elif self .ws_strategy == "fixpoint" :
69- opt = dist_fix_point (W , grad , datafit , penalty , all_feats )
69+ opt = dist_fix_point_bcd (W , grad , datafit , penalty , all_feats )
7070 stop_crit = np .max (opt )
7171 if self .verbose :
7272 print (f"Stopping criterion max violation: { stop_crit :.2e} " )
@@ -150,7 +150,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
150150 if self .ws_strategy == "subdiff" :
151151 opt_ws = penalty .subdiff_distance (W , grad_ws , ws )
152152 elif self .ws_strategy == "fixpoint" :
153- opt_ws = dist_fix_point (
153+ opt_ws = dist_fix_point_bcd (
154154 W , grad_ws , lipschitz , datafit , penalty , ws
155155 )
156156
@@ -231,7 +231,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)
231231
232232
233233@njit
234- def dist_fix_point (W , grad_ws , lipschitz , datafit , penalty , ws ):
234+ def dist_fix_point_bcd (W , grad_ws , lipschitz , datafit , penalty , ws ):
235235 """Compute the violation of the fixed point iterate schema.
236236
237237 Parameters
@@ -256,17 +256,20 @@ def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
256256
257257 Returns
258258 -------
259- dist_fix_point : array, shape (ws_size,)
259+ dist : array, shape (ws_size,)
260260 Contain the violation score for every feature.
261261 """
262- dist_fix_point = np .zeros (ws .shape [0 ])
262+ dist = np .zeros (ws .shape [0 ])
263+
263264 for idx , j in enumerate (ws ):
264- lcj = lipschitz [j ]
265- if lcj :
266- dist_fix_point [idx ] = norm (
267- W [j ] - penalty .prox_1feat (W [j ] - grad_ws [idx ] / lcj , 1. / lcj , j )
268- )
269- return dist_fix_point
265+ if lipschitz [j ] == 0. :
266+ continue
267+
268+ step_j = 1 / lipschitz [j ]
269+ dist [idx ] = norm (
270+ W [j ] - penalty .prox_1feat (W [j ] - step_j * grad_ws [idx ], step_j , j )
271+ )
272+ return dist
270273
271274
272275@njit
0 commit comments