@@ -376,13 +376,13 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
376376 mat = np .zeros ((m , 2 ))
377377
378378 # if a_sorted[i, n-1] >= a_sorted[i, n-2] + eps, then the projection is [0,...,0,eps]
379- done = False
379+ done = early_done = False
380380 active = np .array ([1 ] * m )
381381 after_vec = np .zeros ((m , n ))
382382 proj = a_sorted .copy ()
383383 j = n - 2
384384 while j >= 0 :
385- mat [:, 0 ] = mat [:, 0 ] + a_sorted [:, j + 1 ] # = sum(a_sorted[: i] : i = j + 1,...,n-1
385+ mat [:, 0 ] += a_sorted [:, j + 1 ] # = sum(a_sorted[: i] : i = j + 1,...,n-1
386386 mat [:, 1 ] = a_sorted [:, j ] * (n - j - 1 ) + eps
387387 # Find the max in each problem max{ sum{a_sorted[:, i] : i=j+1,..,n-1} , a_sorted[:, j] * (n-j-1) + eps }
388388 row_maxes = np .max (mat , axis = 1 )
@@ -396,21 +396,29 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
396396 # has to be reduced is delta
397397 delta = (mat [:, 0 ] - eps ) / (n - j - 1 )
398398 # The vector of reductions
399- delta_vec = np .array ([delta ] * (n - j - 1 ))
400- delta_vec = np .transpose (delta_vec )
399+ delta_vec = np .transpose (np .array ([delta ] * (n - j - 1 )))
401400 # The sub-vectors: a_sorted[:, (j+1):]
402401 a_sub = a_sorted [:, (j + 1 ) :]
403402 # After reduction by delta_vec
404403 a_after = a_sub - delta_vec
405404 after_vec [:, (j + 1 ) :] = a_after
406- proj = ( act_multiplier * after_vec ) + (( 1 - act_multiplier ) * proj )
405+ proj += act_multiplier * ( after_vec - proj )
407406 active = active * ind_set
408407 if sum (active ) == 0 :
409- done = True
408+ done = early_done = True
410409 break
411410 j -= 1
411+ if not early_done :
412+ delta = (mat [:, 0 ] + a_sorted [:, 0 ] - eps ) / n
413+ ind_set = np .sign (np .maximum (delta , 0 ))
414+ act_multiplier = ind_set * active
415+ act_multiplier = np .transpose ([np .transpose (act_multiplier )] * n )
416+ delta_vec = np .transpose (np .array ([delta ] * n ))
417+ a_after = a_sorted - delta_vec
418+ proj += act_multiplier * (a_after - proj )
419+ done = True
412420 if not done :
413- proj = active * a_sorted + ( 1 - active ) * proj
421+ proj = active * ( a_sorted - proj )
414422
415423 for i in range (m ):
416424 proj [i , :] = proj [i , a_argsort_inv [i , :]]
@@ -461,7 +469,7 @@ def projection_l1_2(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
461469 mat0 [:, 1 ] = np .min (mat , axis = 1 )
462470 min_t = np .max (mat0 , axis = 1 )
463471 if np .max (min_t ) < 1e-8 :
464- break
472+ continue
465473 row_sums = row_sums - a_var [:, j ] * (n - j )
466474 a_var [:, (j + 1 ) :] = a_var [:, (j + 1 ) :] - np .matmul (min_t .reshape ((m , 1 )), np .ones ((1 , n - j - 1 )))
467475 a_var [:, j ] = a_var [:, j ] - min_t
0 commit comments