Skip to content

Commit 907886f

Browse files
authored
Update utils.py
Fixes some bugs in the functions projection_l1_1 projection_l1_2
1 parent 89bf92f commit 907886f

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

art/utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,13 +376,13 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
376376
mat = np.zeros((m, 2))
377377

378378
# if a_sorted[i, n-1] >= a_sorted[i, n-2] + eps, then the projection is [0,...,0,eps]
379-
done = False
379+
done = early_done = False
380380
active = np.array([1] * m)
381381
after_vec = np.zeros((m, n))
382382
proj = a_sorted.copy()
383383
j = n - 2
384384
while j >= 0:
385-
mat[:, 0] = mat[:, 0] + a_sorted[:, j + 1] # = sum(a_sorted[: i] : i = j + 1,...,n-1
385+
mat[:, 0] += a_sorted[:, j + 1] # = sum(a_sorted[: i] : i = j + 1,...,n-1
386386
mat[:, 1] = a_sorted[:, j] * (n - j - 1) + eps
387387
# Find the max in each problem max{ sum{a_sorted[:, i] : i=j+1,..,n-1} , a_sorted[:, j] * (n-j-1) + eps }
388388
row_maxes = np.max(mat, axis=1)
@@ -396,21 +396,31 @@ def projection_l1_1(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
396396
# has to be reduced is delta
397397
delta = (mat[:, 0] - eps) / (n - j - 1)
398398
# The vector of reductions
399-
delta_vec = np.array([delta] * (n - j - 1))
400-
delta_vec = np.transpose(delta_vec)
399+
delta_vec = np.transpose(np.array([delta] * (n - j - 1)))
401400
# The sub-vectors: a_sorted[:, (j+1):]
402401
a_sub = a_sorted[:, (j + 1) :]
403402
# After reduction by delta_vec
404403
a_after = a_sub - delta_vec
405404
after_vec[:, (j + 1) :] = a_after
406-
proj = (act_multiplier * after_vec) + ((1 - act_multiplier) * proj)
405+
proj += act_multiplier * (after_vec - proj)
407406
active = active * ind_set
408407
if sum(active) == 0:
409-
done = True
408+
done = early_done = True
410409
break
411410
j -= 1
411+
412+
if not early_done:
413+
delta = (mat[:, 0] + a_sorted[:, 0] - eps) / n
414+
ind_set = np.sign(np.maximum(delta, 0))
415+
act_multiplier = ind_set * active
416+
act_multiplier = np.transpose([np.transpose(act_multiplier)] * n)
417+
delta_vec = np.transpose(np.array([delta] * n))
418+
a_after = a_sorted - delta_vec
419+
proj += act_multiplier * (a_after - proj)
420+
done = True
421+
412422
if not done:
413-
proj = active * a_sorted + (1 - active) * proj
423+
proj = active * (a_sorted - proj)
414424

415425
for i in range(m):
416426
proj[i, :] = proj[i, a_argsort_inv[i, :]]
@@ -461,7 +471,7 @@ def projection_l1_2(values: np.ndarray, eps: Union[int, float, np.ndarray]) -> n
461471
mat0[:, 1] = np.min(mat, axis=1)
462472
min_t = np.max(mat0, axis=1)
463473
if np.max(min_t) < 1e-8:
464-
break
474+
continue
465475
row_sums = row_sums - a_var[:, j] * (n - j)
466476
a_var[:, (j + 1) :] = a_var[:, (j + 1) :] - np.matmul(min_t.reshape((m, 1)), np.ones((1, n - j - 1)))
467477
a_var[:, j] = a_var[:, j] - min_t

0 commit comments

Comments
 (0)