Skip to content

Commit 16b46cd

Browse files
authored
Merge pull request #627 from mtsokol/sparse-sort-and-take
API: Add `sort` and `take` functions for COO format
2 parents 82fb0d5 + 5bd29ad commit 16b46cd

File tree

6 files changed

+254
-35
lines changed

6 files changed

+254
-35
lines changed

docs/generated/sparse.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ API
144144

145145
save_npz
146146

147+
sort
148+
147149
squeeze
148150

149151
stack
@@ -152,6 +154,8 @@ API
152154

153155
sum
154156

157+
take
158+
155159
tensordot
156160

157161
tril

sparse/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@
143143
nansum,
144144
result_type,
145145
roll,
146+
sort,
147+
take,
146148
tril,
147149
triu,
148150
unique_counts,
@@ -283,13 +285,15 @@
283285
"sign",
284286
"sin",
285287
"sinh",
288+
"sort",
286289
"sqrt",
287290
"square",
288291
"squeeze",
289292
"stack",
290293
"std",
291294
"subtract",
292295
"sum",
296+
"take",
293297
"tan",
294298
"tanh",
295299
"tensordot",

sparse/_coo/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
nansum,
2020
result_type,
2121
roll,
22+
sort,
2223
stack,
24+
take,
2325
tril,
2426
triu,
2527
unique_counts,
@@ -51,7 +53,9 @@
5153
"nansum",
5254
"result_type",
5355
"roll",
56+
"sort",
5457
"stack",
58+
"take",
5559
"tril",
5660
"triu",
5761
"unique_counts",

sparse/_coo/common.py

Lines changed: 178 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from collections.abc import Iterable
44
from functools import reduce
5-
from typing import NamedTuple, Optional, Tuple
5+
from typing import Any, NamedTuple, Optional, Tuple
66

77
import numba
88

@@ -1090,14 +1090,8 @@ def expand_dims(x, /, *, axis=0):
10901090
(1, 6, 1)
10911091
10921092
"""
1093-
from .core import COO
10941093

1095-
if isinstance(x, scipy.sparse.spmatrix):
1096-
x = COO.from_scipy_sparse(x)
1097-
elif not isinstance(x, SparseArray):
1098-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1099-
elif not isinstance(x, COO):
1100-
x = x.asformat(COO)
1094+
x = _validate_coo_input(x)
11011095

11021096
if not isinstance(axis, int):
11031097
raise IndexError(f"Invalid axis position: {axis}")
@@ -1109,6 +1103,8 @@ def expand_dims(x, /, *, axis=0):
11091103
new_shape.insert(axis, 1)
11101104
new_shape = tuple(new_shape)
11111105

1106+
from .core import COO
1107+
11121108
return COO(
11131109
new_coords,
11141110
x.data,
@@ -1140,14 +1136,8 @@ def flip(x, /, *, axis=None):
11401136
relative to ``x``, are reordered.
11411137
11421138
"""
1143-
from .core import COO
11441139

1145-
if isinstance(x, scipy.sparse.spmatrix):
1146-
x = COO.from_scipy_sparse(x)
1147-
elif not isinstance(x, SparseArray):
1148-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1149-
elif not isinstance(x, COO):
1150-
x = x.asformat(COO)
1140+
x = _validate_coo_input(x)
11511141

11521142
if axis is None:
11531143
axis = range(x.ndim)
@@ -1158,6 +1148,8 @@ def flip(x, /, *, axis=None):
11581148
for ax in axis:
11591149
new_coords[ax, :] = x.shape[ax] - 1 - x.coords[ax, :]
11601150

1151+
from .core import COO
1152+
11611153
return COO(
11621154
new_coords,
11631155
x.data,
@@ -1203,14 +1195,8 @@ def unique_counts(x, /):
12031195
>>> sparse.unique_counts(x)
12041196
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
12051197
"""
1206-
from .core import COO
12071198

1208-
if isinstance(x, scipy.sparse.spmatrix):
1209-
x = COO.from_scipy_sparse(x)
1210-
elif not isinstance(x, SparseArray):
1211-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1212-
elif not isinstance(x, COO):
1213-
x = x.asformat(COO)
1199+
x = _validate_coo_input(x)
12141200

12151201
x = x.flatten()
12161202
values, counts = np.unique(x.data, return_counts=True)
@@ -1250,6 +1236,116 @@ def unique_values(x, /):
12501236
>>> sparse.unique_values(x)
12511237
array([-3, 0, 1, 2])
12521238
"""
1239+
1240+
x = _validate_coo_input(x)
1241+
1242+
x = x.flatten()
1243+
values = np.unique(x.data)
1244+
if x.nnz < x.size:
1245+
values = np.sort(np.concatenate([[x.fill_value], values]))
1246+
return values
1247+
1248+
1249+
def sort(x, /, *, axis=-1, descending=False):
1250+
"""
1251+
Returns a sorted copy of an input array ``x``.
1252+
1253+
Parameters
1254+
----------
1255+
x : SparseArray
1256+
Input array. Should have a real-valued data type.
1257+
axis : int
1258+
Axis along which to sort. If set to ``-1``, the function must sort along
1259+
the last axis. Default: ``-1``.
1260+
descending : bool
1261+
Sort order. If ``True``, the array must be sorted in descending order (by value).
1262+
If ``False``, the array must be sorted in ascending order (by value).
1263+
Default: ``False``.
1264+
1265+
Returns
1266+
-------
1267+
out : COO
1268+
A sorted array.
1269+
1270+
Raises
1271+
------
1272+
ValueError
1273+
If the input array isn't and can't be converted to COO format.
1274+
1275+
Examples
1276+
--------
1277+
>>> import sparse
1278+
>>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
1279+
>>> sparse.sort(x).todense()
1280+
array([-3, 0, 0, 1, 2, 2])
1281+
>>> sparse.sort(x, descending=True).todense()
1282+
array([ 2, 2, 1, 0, 0, -3])
1283+
1284+
"""
1285+
1286+
from .._common import moveaxis
1287+
from .core import COO
1288+
1289+
x = _validate_coo_input(x)
1290+
1291+
original_ndim = x.ndim
1292+
if x.ndim == 1:
1293+
x = x[None, :]
1294+
axis = -1
1295+
1296+
x = moveaxis(x, source=axis, destination=-1)
1297+
x_shape = x.shape
1298+
x = x.reshape((-1, x_shape[-1]))
1299+
1300+
new_coords, new_data = _sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)
1301+
1302+
x = COO(new_coords, new_data, x.shape, has_duplicates=False, sorted=True, fill_value=x.fill_value)
1303+
1304+
x = x.reshape(x_shape[:-1] + (x_shape[-1],))
1305+
x = moveaxis(x, source=-1, destination=axis)
1306+
1307+
return x if original_ndim == x.ndim else x.squeeze()
1308+
1309+
1310+
def take(x, indices, /, *, axis=None):
1311+
"""
1312+
Returns elements of an array along an axis.
1313+
1314+
Parameters
1315+
----------
1316+
x : SparseArray
1317+
Input array.
1318+
indices : ndarray
1319+
Array indices. The array must be one-dimensional and have an integer data type.
1320+
axis : int
1321+
Axis over which to select values. If ``axis`` is negative, the function must
1322+
determine the axis along which to select values by counting from the last dimension.
1323+
For ``None``, the flattened input array is used. Default: ``None``.
1324+
1325+
Returns
1326+
-------
1327+
out : COO
1328+
A COO array with requested indices.
1329+
1330+
Raises
1331+
------
1332+
ValueError
1333+
If the input array isn't and can't be converted to COO format.
1334+
1335+
"""
1336+
1337+
x = _validate_coo_input(x)
1338+
1339+
if axis is None:
1340+
x = x.flatten()
1341+
return x[indices]
1342+
1343+
axis = normalize_axis(axis, x.ndim)
1344+
full_index = (slice(None),) * axis + (indices, ...)
1345+
return x[full_index]
1346+
1347+
1348+
def _validate_coo_input(x: Any):
12531349
from .core import COO
12541350

12551351
if isinstance(x, scipy.sparse.spmatrix):
@@ -1259,11 +1355,65 @@ def unique_values(x, /):
12591355
elif not isinstance(x, COO):
12601356
x = x.asformat(COO)
12611357

1262-
x = x.flatten()
1263-
values = np.unique(x.data)
1264-
if x.nnz < x.size:
1265-
values = np.sort(np.concatenate([[x.fill_value], values]))
1266-
return values
1358+
return x
1359+
1360+
1361+
@numba.jit(nopython=True, nogil=True)
1362+
def _sort_coo(
1363+
coords: np.ndarray,
1364+
data: np.ndarray,
1365+
fill_value: float,
1366+
sort_axis_len: int,
1367+
descending: bool,
1368+
) -> Tuple[np.ndarray, np.ndarray]:
1369+
assert coords.shape[0] == 2
1370+
group_coords = coords[0, :]
1371+
sort_coords = coords[1, :]
1372+
1373+
data = data.copy()
1374+
result_indices = np.empty_like(sort_coords)
1375+
1376+
# We iterate through all groups and sort each one of them.
1377+
# first and last index of a group is tracked.
1378+
prev_group = -1
1379+
group_first_idx = -1
1380+
group_last_idx = -1
1381+
# We add `-1` sentinel to know when the last group ends
1382+
for idx, group in enumerate(np.append(group_coords, -1)):
1383+
if group == prev_group:
1384+
continue
1385+
1386+
if prev_group != -1:
1387+
group_last_idx = idx
1388+
1389+
group_slice = slice(group_first_idx, group_last_idx)
1390+
group_size = group_last_idx - group_first_idx
1391+
1392+
# SORT VALUES
1393+
if group_size > 1:
1394+
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1395+
# keyword can't be supported.
1396+
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1397+
data[group_slice] = np.sort(data[group_slice])
1398+
if descending:
1399+
data[group_slice] = data[group_slice][::-1]
1400+
1401+
# SORT INDICES
1402+
fill_value_count = sort_axis_len - group_size
1403+
indices = np.arange(group_size)
1404+
# find a place where fill_value would be
1405+
for pos in range(group_size):
1406+
if (not descending and fill_value < data[group_slice][pos]) or (
1407+
descending and fill_value > data[group_slice][pos]
1408+
):
1409+
indices[pos:] += fill_value_count
1410+
break
1411+
result_indices[group_first_idx:group_last_idx] = indices
1412+
1413+
prev_group = group
1414+
group_first_idx = idx
1415+
1416+
return np.vstack((group_coords, result_indices)), data
12671417

12681418

12691419
@numba.jit(nopython=True, nogil=True)
@@ -1323,14 +1473,7 @@ def _arg_minmax_common(
13231473
assert mode in ("max", "min")
13241474
max_mode_flag = mode == "max"
13251475

1326-
from .core import COO
1327-
1328-
if isinstance(x, scipy.sparse.spmatrix):
1329-
x = COO.from_scipy_sparse(x)
1330-
elif not isinstance(x, SparseArray):
1331-
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1332-
elif not isinstance(x, COO):
1333-
x = x.asformat(COO)
1476+
x = _validate_coo_input(x)
13341477

13351478
if not isinstance(axis, (int, type(None))):
13361479
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")

0 commit comments

Comments
 (0)