Skip to content

Commit d85ec73

Browse files
authored
BUG: array_api.argsort(descending=True) respects relative sort order (#20788)
* BUG: `array_api.argsort(descending=True)` respects relative order * Regression test for stable descending `array_api.argsort()` Original NumPy Commit: d7a43dfa91cc1363db64da8915db2b4b6c847b81
1 parent 37bf553 commit d85ec73

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

array_api_strict/_sorting_functions.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,20 @@ def argsort(
1515
"""
1616
# Note: this keyword argument is different, and the default is different.
1717
kind = "stable" if stable else "quicksort"
18-
res = np.argsort(x._array, axis=axis, kind=kind)
19-
if descending:
20-
res = np.flip(res, axis=axis)
18+
if not descending:
19+
res = np.argsort(x._array, axis=axis, kind=kind)
20+
else:
21+
# As NumPy has no native descending sort, we imitate it here. Note that
22+
# simply flipping the results of np.argsort(x._array, ...) would not
23+
# respect the relative order like it would in native descending sorts.
24+
res = np.flip(
25+
np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind),
26+
axis=axis,
27+
)
28+
# Rely on flip()/argsort() to validate axis
29+
normalised_axis = axis if axis >= 0 else x.ndim + axis
30+
max_i = x.shape[normalised_axis] - 1
31+
res = max_i - res
2132
return Array._new(res)
2233

2334

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
import array_api_strict as xp
4+
5+
6+
@pytest.mark.parametrize(
7+
"obj, axis, expected",
8+
[
9+
([0, 0], -1, [0, 1]),
10+
([0, 1, 0], -1, [1, 0, 2]),
11+
([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
12+
([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
13+
],
14+
)
15+
def test_stable_desc_argsort(obj, axis, expected):
16+
"""
17+
Indices respect relative order of a descending stable-sort
18+
19+
See https://github.com/numpy/numpy/issues/20778
20+
"""
21+
x = xp.asarray(obj)
22+
out = xp.argsort(x, axis=axis, stable=True, descending=True)
23+
assert xp.all(out == xp.asarray(expected))

0 commit comments

Comments
 (0)