Skip to content

Commit c61373d

Browse files
authored
Merge pull request #1586 from nmegiddo/master
Add functions for orthogonal projection on L1 balls
2 parents eddade6 + 183b672 commit c61373d

File tree

1 file changed

+137
-1
lines changed

1 file changed

+137
-1
lines changed

art/utils.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,145 @@ def wrapper(*args, **kwargs):
334334
# ----------------------------------------------------------------------------------------------------- MATH OPERATIONS
335335

336336

337+
def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> np.ndarray:
    """
    Compute the orthogonal projections of a batch of points onto L1-balls of given radii.

    The batch size is m = values.shape[0]. Each point is flattened to dimension n = np.prod(values.shape[1:]) so
    that its coordinates can be sorted; the result is reshaped back to `values.shape`.

    Working on absolute values sorted as a[0] <= ... <= a[n-1], the projection is characterized by the largest j
    in 0, ..., n-2 such that a[j+1] + ... + a[n-1] - a[j] * (n-j-1) >= eps. Coordinates 0, ..., j of the projection
    are 0 and the remaining ones are reduced by delta = (a[j+1] + ... + a[n-1] - eps) / (n-j-1). If no such j
    exists, either the point already lies inside the ball and is returned unchanged, or all n coordinates are
    reduced by (a[0] + ... + a[n-1] - eps) / n.

    :param values: A batch of m points, each an ndarray.
    :param eps: The radius of the L1-balls: a scalar shared by all points, or an array holding one radius per point.
    :return: The projections, as a float64 array with the same shape as `values`.
    """
    # pylint: disable=C0103
    shp = values.shape
    m = shp[0]
    n = int(np.prod(shp[1:]))
    if m == 0 or n == 0:
        # Nothing to project; keep the shape contract without indexing into empty axes.
        return np.zeros(shp)

    flat = values.reshape((m, n))
    sgns = np.sign(flat)
    magnitudes = np.abs(flat)
    # One radius per row; a scalar or size-1 array is shared by the whole batch.
    radii = np.broadcast_to(np.asarray(eps, dtype=float).reshape(-1), (m,))
    rows = np.arange(m)[:, np.newaxis]

    # Sort every row in ascending order; `order` is also used to undo the sorting at the end.
    order = magnitudes.argsort(axis=1)
    a_sorted = magnitudes[rows, order].astype(float)

    # top_sums[:, k - 1] is the sum of the k largest entries of each row, accumulated from the largest down.
    top_sums = np.cumsum(a_sorted[:, ::-1], axis=1)
    # suffix[:, j] = a_sorted[:, j + 1] + ... + a_sorted[:, n - 1]  (0 for j = n - 1)
    suffix = np.zeros((m, n))
    suffix[:, :-1] = top_sums[:, ::-1][:, 1:]

    # found[:, j] is True where zeroing coordinates 0..j still leaves at least `eps` of L1 mass above a_sorted[:, j].
    counts = n - 1 - np.arange(n)
    found = suffix >= a_sorted * counts + radii[:, np.newaxis]
    # j only ranges over 0..n-2: the support of the projection must keep at least one coordinate.
    found[:, n - 1] = False

    # Largest admissible j per row, or -1 when the support is the whole vector.
    has_j = found.any(axis=1)
    j_star = np.where(has_j, n - 1 - np.argmax(found[:, ::-1], axis=1), -1)

    # Every coordinate in the support is reduced by the same amount delta.
    support_size = n - 1 - j_star
    delta = (top_sums[np.arange(m), support_size - 1] - radii) / support_size
    # Rows without an admissible j that already lie inside the ball must not be moved (delta would be negative).
    delta = np.where(has_j, delta, np.maximum(delta, 0.0))

    in_support = np.arange(n)[np.newaxis, :] > j_star[:, np.newaxis]
    proj_sorted = np.where(in_support, a_sorted - delta[:, np.newaxis], 0.0)

    # Undo the sorting, restore the signs and the original shape.
    proj = np.empty((m, n))
    proj[rows, order] = proj_sorted
    return (sgns * proj).reshape(shp)
412+
413+
414+
def projection_l1_2(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> np.ndarray:
    """
    Compute the orthogonal projections of a batch of points onto L1-balls of given radii.

    The batch size is m = values.shape[0]. Each point is flattened to dimension n = np.prod(values.shape[1:]) so
    that its coordinates can be sorted; the result is reshaped back to `values.shape`.

    Working on absolute values sorted in ascending order, the algorithm repeatedly subtracts a common amount t
    from all coordinates that are still non-zero. With k such coordinates whose sum is s and whose smallest value
    is a_min, t = min((s - eps) / k, a_min). If t = (s - eps) / k, the result is the desired projection. Otherwise
    the smallest coordinate becomes 0 and the problem is reduced to projecting the remaining k - 1 coordinates.

    :param values: A batch of m points, each an ndarray.
    :param eps: The radius of the L1-balls: a scalar shared by all points, or an array holding one radius per point.
    :return: The projections, as a float64 array with the same shape as `values`.
    """
    # pylint: disable=C0103
    shp = values.shape
    m = shp[0]
    n = int(np.prod(shp[1:]))
    if m == 0 or n == 0:
        # Nothing to project; keep the shape contract without indexing into empty axes.
        return np.zeros(shp)

    flat = values.reshape((m, n))
    sgns = np.sign(flat)
    magnitudes = np.abs(flat)
    # One radius per row; a scalar or size-1 array is shared by the whole batch.
    radii = np.broadcast_to(np.asarray(eps, dtype=float).reshape(-1), (m,))
    rows = np.arange(m)[:, np.newaxis]

    # Sort every row in ascending order; `order` is also used to undo the sorting at the end.
    order = magnitudes.argsort(axis=1)
    a_var = magnitudes[rows, order].astype(float)

    # L1 mass that still has to be removed from each row; non-positive once a row lies inside its ball.
    excess = magnitudes.sum(axis=1) - radii
    for j in range(n):
        # Stop on the remaining excess, not on the step size: a (near-)zero smallest coordinate yields a zero
        # step even though the point may still be outside the ball.
        if not np.any(excess > 0):
            break
        remaining = n - j
        min_t = np.clip(np.minimum(excess / remaining, a_var[:, j]), 0.0, None)
        # Charging the full a_var[:, j] (rather than min_t) drives `excess` below zero for rows that have just
        # reached their ball, which keeps all of their later steps at exactly 0.
        excess = excess - a_var[:, j] * remaining
        a_var[:, j:] -= min_t[:, np.newaxis]

    # Undo the sorting, restore the signs and the original shape.
    proj = np.empty((m, n))
    proj[rows, order] = a_var
    return (sgns * proj).reshape(shp)
465+
466+
337467
def projection(values: np.ndarray, eps: Union[int, float, np.ndarray], norm_p: Union[int, float, str]) -> np.ndarray:
338468
"""
339469
Project `values` on the L_p norm ball of size `eps`.
340470
341471
:param values: Array of perturbations to clip.
342472
:param eps: Maximum norm allowed.
343-
:param norm_p: L_p norm to use for clipping. Only 1, 2, `np.Inf` and "inf" supported for now.
473+
:param norm_p: L_p norm to use for clipping.
474+
Only 1, 2 , `np.Inf` 1.1 and 1.2 supported for now.
475+
1.1 and 1.2 compute orthogonal projections on l1-ball, using two different algorithms
344476
:return: Values of `values` after projection.
345477
"""
346478
# Pick a small scalar to avoid division by 0
@@ -363,6 +495,10 @@ def projection(values: np.ndarray, eps: Union[int, float, np.ndarray], norm_p: U
363495
np.minimum(1.0, eps / (np.linalg.norm(values_tmp, axis=1, ord=1) + tol)),
364496
axis=1,
365497
)
498+
elif norm_p == 1.1:
499+
values_tmp = projection_l1_1(values_tmp, eps)
500+
elif norm_p == 1.2:
501+
values_tmp = projection_l1_2(values_tmp, eps)
366502

367503
elif norm_p in [np.inf, "inf"]:
368504
if isinstance(eps, np.ndarray):

0 commit comments

Comments
 (0)