Skip to content

Commit 2e2546a

Browse files
committed
API: Add sort and take functions for COO format
1 parent 82fb0d5 commit 2e2546a

File tree

5 files changed

+210
-21
lines changed

5 files changed

+210
-21
lines changed

sparse/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@
143143
nansum,
144144
result_type,
145145
roll,
146+
sort,
147+
take,
146148
tril,
147149
triu,
148150
unique_counts,
@@ -283,13 +285,15 @@
283285
"sign",
284286
"sin",
285287
"sinh",
288+
"sort",
286289
"sqrt",
287290
"square",
288291
"squeeze",
289292
"stack",
290293
"std",
291294
"subtract",
292295
"sum",
296+
"take",
293297
"tan",
294298
"tanh",
295299
"tensordot",

sparse/_coo/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
nansum,
2020
result_type,
2121
roll,
22+
sort,
2223
stack,
24+
take,
2325
tril,
2426
triu,
2527
unique_counts,
@@ -51,7 +53,9 @@
5153
"nansum",
5254
"result_type",
5355
"roll",
56+
"sort",
5457
"stack",
58+
"take",
5559
"tril",
5660
"triu",
5761
"unique_counts",

sparse/_coo/common.py

Lines changed: 156 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from collections.abc import Iterable
44
from functools import reduce
5-
from typing import NamedTuple, Optional, Tuple
5+
from typing import Any, NamedTuple, Optional, Tuple
66

77
import numba
88

@@ -1203,14 +1203,8 @@ def unique_counts(x, /):
12031203
>>> sparse.unique_counts(x)
12041204
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
12051205
"""
1206-
from .core import COO
12071206

1208-
if isinstance(x, scipy.sparse.spmatrix):
1209-
x = COO.from_scipy_sparse(x)
1210-
elif not isinstance(x, SparseArray):
1211-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1212-
elif not isinstance(x, COO):
1213-
x = x.asformat(COO)
1207+
x = _validate_coo_input(x)
12141208

12151209
x = x.flatten()
12161210
values, counts = np.unique(x.data, return_counts=True)
@@ -1250,6 +1244,113 @@ def unique_values(x, /):
12501244
>>> sparse.unique_values(x)
12511245
array([-3, 0, 1, 2])
12521246
"""
1247+
1248+
x = _validate_coo_input(x)
1249+
1250+
x = x.flatten()
1251+
values = np.unique(x.data)
1252+
if x.nnz < x.size:
1253+
values = np.sort(np.concatenate([[x.fill_value], values]))
1254+
return values
1255+
1256+
1257+
def sort(x, /, *, axis=-1, descending=False):
1258+
"""
1259+
Returns a sorted copy of an input array ``x``.
1260+
1261+
Parameters
1262+
----------
1263+
x : SparseArray
1264+
Input array. Should have a real-valued data type.
1265+
axis : int
1266+
Axis along which to sort. If set to ``-1``, the function must sort along
1267+
the last axis. Default: ``-1``.
1268+
descending : bool
1269+
Sort order. If ``True``, the array must be sorted in descending order (by value).
1270+
If ``False``, the array must be sorted in ascending order (by value).
1271+
Default: ``False``.
1272+
1273+
Returns
1274+
-------
1275+
out : COO
1276+
A sorted array.
1277+
1278+
Raises
1279+
------
1280+
ValueError
1281+
If the input array isn't and can't be converted to COO format.
1282+
1283+
Examples
1284+
--------
1285+
>>> import sparse
1286+
>>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
1287+
>>> sparse.sort(x).todense()
1288+
array([-3, 0, 0, 1, 2, 2])
1289+
>>> sparse.sort(x, descending=True).todense()
1290+
array([ 2, 2, 1, 0, 0, -3])
1291+
1292+
"""
1293+
1294+
from .._common import moveaxis
1295+
1296+
x = _validate_coo_input(x)
1297+
1298+
original_ndim = x.ndim
1299+
if x.ndim == 1:
1300+
x = x[None, :]
1301+
axis = -1
1302+
1303+
x = moveaxis(x, source=axis, destination=-1)
1304+
x_shape = x.shape
1305+
x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))
1306+
1307+
_sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)
1308+
1309+
x = x.reshape(x_shape[:-1] + (x_shape[-1],))
1310+
x = moveaxis(x, source=-1, destination=axis)
1311+
1312+
return x if original_ndim == x.ndim else x.squeeze()
1313+
1314+
1315+
def take(x, indices, /, *, axis=None):
1316+
"""
1317+
Returns elements of an array along an axis.
1318+
1319+
Parameters
1320+
----------
1321+
x : SparseArray
1322+
Input array.
1323+
indices : ndarray
1324+
Array indices. The array must be one-dimensional and have an integer data type.
1325+
axis : int
1326+
Axis over which to select values. If ``axis`` is negative, the function must
1327+
determine the axis along which to select values by counting from the last dimension.
1328+
For ``None``, the flattened input array is used. Default: ``None``.
1329+
1330+
Returns
1331+
-------
1332+
out : COO
1333+
A COO array with requested indices.
1334+
1335+
Raises
1336+
------
1337+
ValueError
1338+
If the input array isn't and can't be converted to COO format.
1339+
1340+
"""
1341+
1342+
x = _validate_coo_input(x)
1343+
1344+
if axis is None:
1345+
x = x.flatten()
1346+
return x[indices]
1347+
1348+
axis = normalize_axis(axis, x.ndim)
1349+
full_index = (slice(None),) * axis + (indices, ...)
1350+
return x[full_index]
1351+
1352+
1353+
def _validate_coo_input(x: Any):
12531354
from .core import COO
12541355

12551356
if isinstance(x, scipy.sparse.spmatrix):
@@ -1259,11 +1360,52 @@ def unique_values(x, /):
12591360
elif not isinstance(x, COO):
12601361
x = x.asformat(COO)
12611362

1262-
x = x.flatten()
1263-
values = np.unique(x.data)
1264-
if x.nnz < x.size:
1265-
values = np.sort(np.concatenate([[x.fill_value], values]))
1266-
return values
1363+
return x
1364+
1365+
1366+
@numba.jit(nopython=True, nogil=True)
1367+
def _sort_coo(
1368+
coords: np.ndarray,
1369+
data: np.ndarray,
1370+
fill_value: float,
1371+
sort_axis_len: int,
1372+
descending: bool,
1373+
) -> None:
1374+
assert coords.shape[0] == 2
1375+
group_coords = coords[0, :]
1376+
sort_coords = coords[1, :]
1377+
1378+
result_indices = np.empty_like(sort_coords)
1379+
offset = 0 # tracks where the current group starts
1380+
1381+
# iterate through all groups and sort each one of them
1382+
for unique_val in np.unique(group_coords):
1383+
# .copy() required by numba, as `reshape` expects a continous array
1384+
group = np.argwhere(group_coords == unique_val).copy()
1385+
group = np.reshape(group, -1)
1386+
group = np.atleast_1d(group)
1387+
1388+
# SORT VALUES
1389+
if group.size > 1:
1390+
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1391+
# keyword can't be supported.
1392+
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1393+
data[group] = np.sort(data[group])
1394+
if descending:
1395+
data[group] = data[group][::-1]
1396+
1397+
# SORT INDICES
1398+
fill_value_count = sort_axis_len - group.size
1399+
indices = np.arange(group.size)
1400+
# find a place where fill_value would be
1401+
for pos in range(group.size):
1402+
if (not descending and fill_value < data[group][pos]) or (descending and fill_value > data[group][pos]):
1403+
indices[pos:] += fill_value_count
1404+
break
1405+
result_indices[offset : offset + len(indices)] = indices
1406+
offset += len(indices)
1407+
1408+
sort_coords[:] = result_indices
12671409

12681410

12691411
@numba.jit(nopython=True, nogil=True)
@@ -1323,14 +1465,7 @@ def _arg_minmax_common(
13231465
assert mode in ("max", "min")
13241466
max_mode_flag = mode == "max"
13251467

1326-
from .core import COO
1327-
1328-
if isinstance(x, scipy.sparse.spmatrix):
1329-
x = COO.from_scipy_sparse(x)
1330-
elif not isinstance(x, SparseArray):
1331-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1332-
elif not isinstance(x, COO):
1333-
x = x.asformat(COO)
1468+
x = _validate_coo_input(x)
13341469

13351470
if not isinstance(axis, (int, type(None))):
13361471
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")

sparse/tests/test_coo.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,25 @@ def test_expand_dims(axis):
17901790
np.testing.assert_equal(result.todense(), expected)
17911791

17921792

1793+
@pytest.mark.parametrize(
1794+
"arr",
1795+
[
1796+
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
1797+
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64),
1798+
],
1799+
)
1800+
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
1801+
@pytest.mark.parametrize("axis", [0, 1, -1])
1802+
@pytest.mark.parametrize("descending", [False, True])
1803+
def test_sort(arr, fill_value, axis, descending):
1804+
s_arr = sparse.COO.from_numpy(arr, fill_value)
1805+
1806+
result = sparse.sort(s_arr, axis=axis, descending=descending)
1807+
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)
1808+
1809+
np.testing.assert_equal(result.todense(), expected)
1810+
1811+
17931812
@pytest.mark.parametrize("axis", [None, -1, 0, 1, 2, (0, 1), (2, 0)])
17941813
def test_flip(axis):
17951814
arr = np.arange(24).reshape((2, 3, 4))
@@ -1799,3 +1818,28 @@ def test_flip(axis):
17991818
expected = np.flip(arr, axis=axis)
18001819

18011820
np.testing.assert_equal(result.todense(), expected)
1821+
1822+
1823+
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
1824+
@pytest.mark.parametrize(
1825+
"indices,axis",
1826+
[
1827+
(
1828+
[1],
1829+
0,
1830+
),
1831+
([2, 1], 1),
1832+
([1, 2, 3], 2),
1833+
([2, 3], -1),
1834+
([5, 3, 7, 8], None),
1835+
],
1836+
)
1837+
def test_take(fill_value, indices, axis):
1838+
arr = np.arange(24).reshape((2, 3, 4))
1839+
1840+
s_arr = sparse.COO.from_numpy(arr, fill_value)
1841+
1842+
result = sparse.take(s_arr, np.array(indices), axis=axis)
1843+
expected = np.take(arr, indices, axis)
1844+
1845+
np.testing.assert_equal(result.todense(), expected)

sparse/tests/test_namespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,15 @@ def test_namespace():
130130
"sign",
131131
"sin",
132132
"sinh",
133+
"sort",
133134
"sqrt",
134135
"square",
135136
"squeeze",
136137
"stack",
137138
"std",
138139
"subtract",
139140
"sum",
141+
"take",
140142
"tan",
141143
"tanh",
142144
"tensordot",

0 commit comments

Comments
 (0)