@@ -334,13 +334,145 @@ def wrapper(*args, **kwargs):
334334# ----------------------------------------------------------------------------------------------------- MATH OPERATIONS
335335
336336
337+ def projection_l1_1 (values : np .ndarray , eps : Union [int , float , np .ndarray ]) -> np .ndarray :
338+ """
339+ This function computes the orthogonal projections of a batch of points on L1-balls of given radii
340+ The batch size is m = values.shape[0]. The points are flattened to dimension
341+ n = np.prod(value.shape[1:]). This is required to facilitate sorting.
342+
343+ If a[0] <= ... <= a[n-1], then the projection can be characterized using the largest j such that
344+ a[j+1] +...+ a[n-1] - a[j]*(n-j-1) >= eps. The ith coordinate of projection is equal to 0
345+ if i=0,...,j.
346+
347+ :param values: A batch of m points, each an ndarray
348+ :param eps: The radii of the respective L1-balls
349+ :return: projections
350+ """
351+ # pylint: disable=C0103
352+
353+ shp = values .shape
354+ a = values .copy ()
355+ n = np .prod (a .shape [1 :])
356+ m = a .shape [0 ]
357+ a = a .reshape ((m , n ))
358+ sgns = np .sign (a )
359+ a = np .abs (a )
360+
361+ a_argsort = a .argsort (axis = 1 )
362+ a_sorted = np .zeros ((m , n ))
363+ for i in range (m ):
364+ a_sorted [i , :] = a [i , a_argsort [i , :]]
365+ a_argsort_inv = a .argsort (axis = 1 ).argsort (axis = 1 )
366+ mat = np .zeros ((m , 2 ))
367+
368+ # if a_sorted[i, n-1] >= a_sorted[i, n-2] + eps, then the projection is [0,...,0,eps]
369+ done = False
370+ active = np .array ([1 ] * m )
371+ after_vec = np .zeros ((m , n ))
372+ proj = a_sorted .copy ()
373+ j = n - 2
374+ while j >= 0 :
375+ mat [:, 0 ] = mat [:, 0 ] + a_sorted [:, j + 1 ] # = sum(a_sorted[: i] : i = j + 1,...,n-1
376+ mat [:, 1 ] = a_sorted [:, j ] * (n - j - 1 ) + eps
377+ # Find the max in each problem max{ sum{a_sorted[:, i] : i=j+1,..,n-1} , a_sorted[:, j] * (n-j-1) + eps }
378+ row_maxes = np .max (mat , axis = 1 )
379+ # Set to 1 if max > a_sorted[:, j] * (n-j-1) + eps > sum ; otherwise, set to 0
380+ ind_set = np .sign (np .sign (row_maxes - mat [:, 0 ]))
381+ # ind_set = ind_set.reshape((m, 1))
382+ # Multiplier for activation
383+ act_multiplier = (1 - ind_set ) * active
384+ act_multiplier = np .transpose ([np .transpose (act_multiplier )] * n )
385+ # if done, the projection is supported by the current indices j+1,..,n-1 and the amount by which each
386+ # has to be reduced is delta
387+ delta = (mat [:, 0 ] - eps ) / (n - j - 1 )
388+ # The vector of reductions
389+ delta_vec = np .array ([delta ] * (n - j - 1 ))
390+ delta_vec = np .transpose (delta_vec )
391+ # The sub-vectors: a_sorted[:, (j+1):]
392+ a_sub = a_sorted [:, (j + 1 ) :]
393+ # After reduction by delta_vec
394+ a_after = a_sub - delta_vec
395+ after_vec [:, (j + 1 ) :] = a_after
396+ proj = (act_multiplier * after_vec ) + ((1 - act_multiplier ) * proj )
397+ active = active * ind_set
398+ if sum (active ) == 0 :
399+ done = True
400+ break
401+ j -= 1
402+ if not done :
403+ proj = active * a_sorted + (1 - active ) * proj
404+
405+ for i in range (m ):
406+ proj [i , :] = proj [i , a_argsort_inv [i , :]]
407+
408+ proj = sgns * proj
409+ proj = proj .reshape (shp )
410+
411+ return proj
412+
413+
414+ def projection_l1_2 (values : np .ndarray , eps : Union [int , float , np .ndarray ]) -> np .ndarray :
415+ """
416+ This function computes the orthogonal projections of a batch of points on L1-balls of given radii
417+ The batch size is m = values.shape[0]. The points are flattened to dimension
418+ n = np.prod(value.shape[1:]). This is required to facilitate sorting.
419+
420+ Starting from a vector a = (a1,...,an) such that a1 >= ... >= an >= 0, a1 + ... + an > 1,
421+ we first move to a' = a - (t,...,t) such that either a1 + ... + an >= 1 , an >= 0,
422+ and min( a1 + ... + an - nt - 1, an -t ) = 0. This means t = min( (a1 + ... + an - 1)/n, an).
423+ If t = (a1 + ... + an - 1)/n , then a' is the desired projection. Otherwise, the problem is reduced to
424+ finding the projection of (a1 - t, ... , a{n-1} - t ).
425+
426+ :param values: A batch of m points, each an ndarray
427+ :param eps: The radii of the respective L1-balls
428+ :return: projections
429+ """
430+ # pylint: disable=C0103
431+ shp = values .shape
432+ a = values .copy ()
433+ n = np .prod (a .shape [1 :])
434+ m = a .shape [0 ]
435+ a = a .reshape ((m , n ))
436+ sgns = np .sign (a )
437+ a = np .abs (a )
438+ a_argsort = a .argsort (axis = 1 )
439+ a_sorted = np .zeros ((m , n ))
440+ for i in range (m ):
441+ a_sorted [i , :] = a [i , a_argsort [i , :]]
442+
443+ a_argsort_inv = a .argsort (axis = 1 ).argsort (axis = 1 )
444+ row_sums = np .sum (a , axis = 1 )
445+ mat = np .zeros ((m , 2 ))
446+ mat0 = np .zeros ((m , 2 ))
447+ a_var = a_sorted .copy ()
448+ for j in range (n ):
449+ mat [:, 0 ] = (row_sums - eps ) / (n - j )
450+ mat [:, 1 ] = a_var [:, j ]
451+ mat0 [:, 1 ] = np .min (mat , axis = 1 )
452+ min_t = np .max (mat0 , axis = 1 )
453+ if np .max (min_t ) < 1e-8 :
454+ break
455+ row_sums = row_sums - a_var [:, j ] * (n - j )
456+ a_var [:, (j + 1 ) :] = a_var [:, (j + 1 ) :] - np .matmul (min_t .reshape ((m , 1 )), np .ones ((1 , n - j - 1 )))
457+ a_var [:, j ] = a_var [:, j ] - min_t
458+ proj = np .zeros ((m , n ))
459+ for i in range (m ):
460+ proj [i , :] = a_var [i , a_argsort_inv [i , :]]
461+
462+ proj = sgns * proj
463+ proj = proj .reshape (shp )
464+ return proj
465+
466+
337467def projection (values : np .ndarray , eps : Union [int , float , np .ndarray ], norm_p : Union [int , float , str ]) -> np .ndarray :
338468 """
339469 Project `values` on the L_p norm ball of size `eps`.
340470
341471 :param values: Array of perturbations to clip.
342472 :param eps: Maximum norm allowed.
343- :param norm_p: L_p norm to use for clipping. Only 1, 2, `np.Inf` and "inf" supported for now.
473+ :param norm_p: L_p norm to use for clipping.
474+ Only 1, 2 , `np.Inf` 1.1 and 1.2 supported for now.
475+ 1.1 and 1.2 compute orthogonal projections on l1-ball, using two different algorithms
344476 :return: Values of `values` after projection.
345477 """
346478 # Pick a small scalar to avoid division by 0
@@ -363,6 +495,10 @@ def projection(values: np.ndarray, eps: Union[int, float, np.ndarray], norm_p: U
363495 np .minimum (1.0 , eps / (np .linalg .norm (values_tmp , axis = 1 , ord = 1 ) + tol )),
364496 axis = 1 ,
365497 )
498+ elif norm_p == 1.1 :
499+ values_tmp = projection_l1_1 (values_tmp , eps )
500+ elif norm_p == 1.2 :
501+ values_tmp = projection_l1_2 (values_tmp , eps )
366502
367503 elif norm_p in [np .inf , "inf" ]:
368504 if isinstance (eps , np .ndarray ):
0 commit comments