Skip to content

Commit a79c797

Browse files
authored
Merge pull request #1870 from nmegiddo/patch-5
Update utils.py
2 parents dc29bbc + 9fa7ec4 commit a79c797

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

art/utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,13 +376,13 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
376376
mat = np.zeros((m, 2))
377377

378378
# if a_sorted[i, n-1] >= a_sorted[i, n-2] + eps, then the projection is [0,...,0,eps]
379-
done = False
379+
done = early_done = False
380380
active = np.array([1] * m)
381381
after_vec = np.zeros((m, n))
382382
proj = a_sorted.copy()
383383
j = n - 2
384384
while j >= 0:
385-
mat[:, 0] = mat[:, 0] + a_sorted[:, j + 1] # = sum(a_sorted[: i] : i = j + 1,...,n-1
385+
mat[:, 0] += a_sorted[:, j + 1] # = sum(a_sorted[: i] : i = j + 1,...,n-1
386386
mat[:, 1] = a_sorted[:, j] * (n - j - 1) + eps
387387
# Find the max in each problem max{ sum{a_sorted[:, i] : i=j+1,..,n-1} , a_sorted[:, j] * (n-j-1) + eps }
388388
row_maxes = np.max(mat, axis=1)
@@ -396,21 +396,29 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
396396
# has to be reduced is delta
397397
delta = (mat[:, 0] - eps) / (n - j - 1)
398398
# The vector of reductions
399-
delta_vec = np.array([delta] * (n - j - 1))
400-
delta_vec = np.transpose(delta_vec)
399+
delta_vec = np.transpose(np.array([delta] * (n - j - 1)))
401400
# The sub-vectors: a_sorted[:, (j+1):]
402401
a_sub = a_sorted[:, (j + 1) :]
403402
# After reduction by delta_vec
404403
a_after = a_sub - delta_vec
405404
after_vec[:, (j + 1) :] = a_after
406-
proj = (act_multiplier * after_vec) + ((1 - act_multiplier) * proj)
405+
proj += act_multiplier * (after_vec - proj)
407406
active = active * ind_set
408407
if sum(active) == 0:
409-
done = True
408+
done = early_done = True
410409
break
411410
j -= 1
411+
if not early_done:
412+
delta = (mat[:, 0] + a_sorted[:, 0] - eps) / n
413+
ind_set = np.sign(np.maximum(delta, 0))
414+
act_multiplier = ind_set * active
415+
act_multiplier = np.transpose([np.transpose(act_multiplier)] * n)
416+
delta_vec = np.transpose(np.array([delta] * n))
417+
a_after = a_sorted - delta_vec
418+
proj += act_multiplier * (a_after - proj)
419+
done = True
412420
if not done:
413-
proj = active * a_sorted + (1 - active) * proj
421+
proj = active * (a_sorted - proj)
414422

415423
for i in range(m):
416424
proj[i, :] = proj[i, a_argsort_inv[i, :]]
@@ -461,7 +469,7 @@ def projection_l1_2(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
461469
mat0[:, 1] = np.min(mat, axis=1)
462470
min_t = np.max(mat0, axis=1)
463471
if np.max(min_t) < 1e-8:
464-
break
472+
continue
465473
row_sums = row_sums - a_var[:, j] * (n - j)
466474
a_var[:, (j + 1) :] = a_var[:, (j + 1) :] - np.matmul(min_t.reshape((m, 1)), np.ones((1, n - j - 1)))
467475
a_var[:, j] = a_var[:, j] - min_t

0 commit comments

Comments
 (0)