Skip to content

Commit 16cbe86

Browse files
committed
Apply review comments
1 parent 2e2546a commit 16cbe86

File tree

3 files changed

+80
-47
lines changed

3 files changed

+80
-47
lines changed

docs/generated/sparse.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ API
144144

145145
save_npz
146146

147+
sort
148+
147149
squeeze
148150

149151
stack
@@ -152,6 +154,8 @@ API
152154

153155
sum
154156

157+
take
158+
155159
tensordot
156160

157161
tril

sparse/_coo/common.py

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,14 +1090,8 @@ def expand_dims(x, /, *, axis=0):
10901090
(1, 6, 1)
10911091
10921092
"""
1093-
from .core import COO
10941093

1095-
if isinstance(x, scipy.sparse.spmatrix):
1096-
x = COO.from_scipy_sparse(x)
1097-
elif not isinstance(x, SparseArray):
1098-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1099-
elif not isinstance(x, COO):
1100-
x = x.asformat(COO)
1094+
x = _validate_coo_input(x)
11011095

11021096
if not isinstance(axis, int):
11031097
raise IndexError(f"Invalid axis position: {axis}")
@@ -1109,6 +1103,8 @@ def expand_dims(x, /, *, axis=0):
11091103
new_shape.insert(axis, 1)
11101104
new_shape = tuple(new_shape)
11111105

1106+
from .core import COO
1107+
11121108
return COO(
11131109
new_coords,
11141110
x.data,
@@ -1140,14 +1136,8 @@ def flip(x, /, *, axis=None):
11401136
relative to ``x``, are reordered.
11411137
11421138
"""
1143-
from .core import COO
11441139

1145-
if isinstance(x, scipy.sparse.spmatrix):
1146-
x = COO.from_scipy_sparse(x)
1147-
elif not isinstance(x, SparseArray):
1148-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1149-
elif not isinstance(x, COO):
1150-
x = x.asformat(COO)
1140+
x = _validate_coo_input(x)
11511141

11521142
if axis is None:
11531143
axis = range(x.ndim)
@@ -1158,6 +1148,8 @@ def flip(x, /, *, axis=None):
11581148
for ax in axis:
11591149
new_coords[ax, :] = x.shape[ax] - 1 - x.coords[ax, :]
11601150

1151+
from .core import COO
1152+
11611153
return COO(
11621154
new_coords,
11631155
x.data,
@@ -1291,6 +1283,7 @@ def sort(x, /, *, axis=-1, descending=False):
12911283
12921284
"""
12931285

1286+
from .core import COO
12941287
from .._common import moveaxis
12951288

12961289
x = _validate_coo_input(x)
@@ -1302,9 +1295,13 @@ def sort(x, /, *, axis=-1, descending=False):
13021295

13031296
x = moveaxis(x, source=axis, destination=-1)
13041297
x_shape = x.shape
1305-
x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))
1298+
x = x.reshape((-1, x_shape[-1]))
13061299

1307-
_sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)
1300+
new_coords, new_data = _sort_coo(
1301+
x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending
1302+
)
1303+
1304+
x = COO(new_coords, new_data, x.shape, has_duplicates=False, sorted=True, fill_value=x.fill_value)
13081305

13091306
x = x.reshape(x_shape[:-1] + (x_shape[-1],))
13101307
x = moveaxis(x, source=-1, destination=axis)
@@ -1370,42 +1367,55 @@ def _sort_coo(
13701367
fill_value: float,
13711368
sort_axis_len: int,
13721369
descending: bool,
1373-
) -> None:
1370+
) -> Tuple[np.ndarray, np.ndarray]:
13741371
assert coords.shape[0] == 2
13751372
group_coords = coords[0, :]
13761373
sort_coords = coords[1, :]
13771374

1375+
data = data.copy()
13781376
result_indices = np.empty_like(sort_coords)
1379-
offset = 0 # tracks where the current group starts
1380-
1381-
# iterate through all groups and sort each one of them
1382-
for unique_val in np.unique(group_coords):
1383-
# .copy() required by numba, as `reshape` expects a continous array
1384-
group = np.argwhere(group_coords == unique_val).copy()
1385-
group = np.reshape(group, -1)
1386-
group = np.atleast_1d(group)
1387-
1388-
# SORT VALUES
1389-
if group.size > 1:
1390-
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1391-
# keyword can't be supported.
1392-
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1393-
data[group] = np.sort(data[group])
1394-
if descending:
1395-
data[group] = data[group][::-1]
1396-
1397-
# SORT INDICES
1398-
fill_value_count = sort_axis_len - group.size
1399-
indices = np.arange(group.size)
1400-
# find a place where fill_value would be
1401-
for pos in range(group.size):
1402-
if (not descending and fill_value < data[group][pos]) or (descending and fill_value > data[group][pos]):
1403-
indices[pos:] += fill_value_count
1404-
break
1405-
result_indices[offset : offset + len(indices)] = indices
1406-
offset += len(indices)
1407-
1408-
sort_coords[:] = result_indices
1377+
1378+
# We iterate through all groups and sort each one of them.
1379+
# first and last index of a group is tracked.
1380+
prev_group = -1
1381+
group_first_idx = -1
1382+
group_last_idx = -1
1383+
# We add `-1` sentinel to know when the last group ends
1384+
for idx, group in enumerate(np.append(group_coords, -1)):
1385+
if group == prev_group:
1386+
continue
1387+
1388+
if prev_group != -1:
1389+
group_last_idx = idx
1390+
1391+
group_slice = slice(group_first_idx, group_last_idx)
1392+
group_size = group_last_idx - group_first_idx
1393+
1394+
# SORT VALUES
1395+
if group_size > 1:
1396+
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1397+
# keyword can't be supported.
1398+
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1399+
data[group_slice] = np.sort(data[group_slice])
1400+
if descending:
1401+
data[group_slice] = data[group_slice][::-1]
1402+
1403+
# SORT INDICES
1404+
fill_value_count = sort_axis_len - group_size
1405+
indices = np.arange(group_size)
1406+
# find a place where fill_value would be
1407+
for pos in range(group_size):
1408+
if (not descending and fill_value < data[group_slice][pos]) or (
1409+
descending and fill_value > data[group_slice][pos]
1410+
):
1411+
indices[pos:] += fill_value_count
1412+
break
1413+
result_indices[group_first_idx:group_last_idx] = indices
1414+
1415+
prev_group = group
1416+
group_first_idx = idx
1417+
1418+
return np.vstack((group_coords, result_indices)), data
14091419

14101420

14111421
@numba.jit(nopython=True, nogil=True)

sparse/tests/test_coo.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,18 +1795,37 @@ def test_expand_dims(axis):
17951795
[
17961796
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
17971797
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64),
1798+
np.arange(3, 10),
17981799
],
17991800
)
18001801
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
18011802
@pytest.mark.parametrize("axis", [0, 1, -1])
18021803
@pytest.mark.parametrize("descending", [False, True])
18031804
def test_sort(arr, fill_value, axis, descending):
1805+
if axis >= arr.ndim:
1806+
return
1807+
18041808
s_arr = sparse.COO.from_numpy(arr, fill_value)
18051809

18061810
result = sparse.sort(s_arr, axis=axis, descending=descending)
18071811
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)
18081812

18091813
np.testing.assert_equal(result.todense(), expected)
1814+
# make sure no inplace changes happened
1815+
np.testing.assert_equal(s_arr.todense(), arr)
1816+
1817+
1818+
@pytest.mark.parametrize("fill_value", [-1, 0, 1])
1819+
@pytest.mark.parametrize("descending", [False, True])
1820+
def test_sort_only_fill_value(fill_value, descending):
1821+
1822+
arr = np.full((3, 3), fill_value=fill_value)
1823+
s_arr = sparse.COO.from_numpy(arr, fill_value)
1824+
1825+
result = sparse.sort(s_arr, axis=0, descending=descending)
1826+
expected = np.sort(arr, axis=0)
1827+
1828+
np.testing.assert_equal(result.todense(), expected)
18101829

18111830

18121831
@pytest.mark.parametrize("axis", [None, -1, 0, 1, 2, (0, 1), (2, 0)])

0 commit comments

Comments
 (0)