2
2
import warnings
3
3
from collections .abc import Iterable
4
4
from functools import reduce
5
- from typing import NamedTuple , Optional , Tuple
5
+ from typing import Any , NamedTuple , Optional , Tuple
6
6
7
7
import numba
8
8
@@ -1203,14 +1203,8 @@ def unique_counts(x, /):
1203
1203
>>> sparse.unique_counts(x)
1204
1204
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
1205
1205
"""
1206
- from .core import COO
1207
1206
1208
- if isinstance (x , scipy .sparse .spmatrix ):
1209
- x = COO .from_scipy_sparse (x )
1210
- elif not isinstance (x , SparseArray ):
1211
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1212
- elif not isinstance (x , COO ):
1213
- x = x .asformat (COO )
1207
+ x = _validate_coo_input (x )
1214
1208
1215
1209
x = x .flatten ()
1216
1210
values , counts = np .unique (x .data , return_counts = True )
@@ -1250,6 +1244,113 @@ def unique_values(x, /):
1250
1244
>>> sparse.unique_values(x)
1251
1245
array([-3, 0, 1, 2])
1252
1246
"""
1247
+
1248
+ x = _validate_coo_input (x )
1249
+
1250
+ x = x .flatten ()
1251
+ values = np .unique (x .data )
1252
+ if x .nnz < x .size :
1253
+ values = np .sort (np .concatenate ([[x .fill_value ], values ]))
1254
+ return values
1255
+
1256
+
1257
def sort(x, /, *, axis=-1, descending=False):
    """
    Returns a sorted copy of an input array ``x``.

    The input array itself is left untouched: sorting is carried out on a
    private copy of the coordinates and data.

    Parameters
    ----------
    x : SparseArray
        Input array. Should have a real-valued data type.
    axis : int
        Axis along which to sort. If set to ``-1``, the function must sort along
        the last axis. Default: ``-1``.
    descending : bool
        Sort order. If ``True``, the array must be sorted in descending order (by value).
        If ``False``, the array must be sorted in ascending order (by value).
        Default: ``False``.

    Returns
    -------
    out : COO
        A sorted array.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.

    Examples
    --------
    >>> import sparse
    >>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
    >>> sparse.sort(x).todense()
    array([-3,  0,  0,  1,  2,  2])
    >>> sparse.sort(x, descending=True).todense()
    array([ 2,  2,  1,  0,  0, -3])

    """

    from .._common import moveaxis

    x = _validate_coo_input(x)

    original_ndim = x.ndim
    if x.ndim == 1:
        # Promote to 2-D so the (group, sort-axis) layout required by
        # ``_sort_coo`` holds; the only axis then becomes the last one.
        x = x[None, :]
        axis = -1

    # Bring the sort axis to the end and collapse every other axis into a
    # single leading "group" axis.
    x = moveaxis(x, source=axis, destination=-1)
    x_shape = x.shape
    x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))

    # ``_sort_coo`` overwrites ``coords`` and ``data`` in place. The indexing,
    # ``moveaxis`` and ``reshape`` steps above may hand back the caller's own
    # buffers (NOTE(review): depends on their implementation), so work on a
    # deep copy to honour the "sorted copy" contract.
    x = x.copy()

    _sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)

    # Undo the collapse and move the sorted axis back to where it came from.
    x = x.reshape(x_shape)
    x = moveaxis(x, source=-1, destination=axis)

    return x if original_ndim == x.ndim else x.squeeze()
1313
+
1314
+
1315
def take(x, indices, /, *, axis=None):
    """
    Selects elements of an array along a given axis.

    Parameters
    ----------
    x : SparseArray
        Input array.
    indices : ndarray
        Array indices. The array must be one-dimensional and have an integer data type.
    axis : int
        Axis over which to select values. A negative ``axis`` is counted from
        the last dimension. When ``None``, the selection is made on the
        flattened input array. Default: ``None``.

    Returns
    -------
    out : COO
        A COO array with requested indices.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.

    """

    arr = _validate_coo_input(x)

    if axis is not None:
        ax = normalize_axis(axis, arr.ndim)
        # Full slices for every axis before ``ax``, then the requested
        # indices, then an ellipsis covering whatever axes remain.
        leading = (slice(None),) * ax
        selector = leading + (indices, ...)
        return arr[selector]

    # No axis given: index into the flattened array.
    flat = arr.flatten()
    return flat[indices]
1351
+
1352
+
1353
+ def _validate_coo_input (x : Any ):
1253
1354
from .core import COO
1254
1355
1255
1356
if isinstance (x , scipy .sparse .spmatrix ):
@@ -1259,11 +1360,52 @@ def unique_values(x, /):
1259
1360
elif not isinstance (x , COO ):
1260
1361
x = x .asformat (COO )
1261
1362
1262
- x = x .flatten ()
1263
- values = np .unique (x .data )
1264
- if x .nnz < x .size :
1265
- values = np .sort (np .concatenate ([[x .fill_value ], values ]))
1266
- return values
1363
+ return x
1364
+
1365
+
1366
@numba.jit(nopython=True, nogil=True)
def _sort_coo(
    coords: np.ndarray,
    data: np.ndarray,
    fill_value: float,
    sort_axis_len: int,
    descending: bool,
) -> None:
    """
    Sort a two-dimensional COO layout along its second axis, in place.

    Entries sharing the same value in ``coords[0]`` form a group. Within each
    group ``data`` is sorted (reversed when ``descending``), and ``coords[1]``
    is overwritten with the positions the stored values occupy once the
    implicit ``fill_value`` entries are accounted for.

    Parameters
    ----------
    coords : np.ndarray
        Coordinate array with exactly two rows; row 0 identifies the group,
        row 1 is the coordinate along the sorted axis. Row 1 is overwritten.
    data : np.ndarray
        Stored values matching the columns of ``coords``. Overwritten.
    fill_value : float
        Value of the entries that are not stored explicitly.
    sort_axis_len : int
        Length of the axis being sorted (stored plus implicit entries).
    descending : bool
        Sort in descending order when ``True``, ascending otherwise.

    Notes
    -----
    Result positions are written sequentially, group after group in the order
    produced by ``np.unique``. This assumes that the entries of each group are
    stored contiguously and that groups appear in ascending order.
    NOTE(review): confirm callers always pass coordinates in that order.
    """
    assert coords.shape[0] == 2
    group_coords = coords[0, :]
    sort_coords = coords[1, :]

    result_indices = np.empty_like(sort_coords)
    offset = 0  # tracks where the current group starts

    # iterate through all groups and sort each one of them
    for unique_val in np.unique(group_coords):
        # .copy() required by numba, as `reshape` expects a contiguous array
        group = np.argwhere(group_coords == unique_val).copy()
        group = np.reshape(group, -1)
        group = np.atleast_1d(group)

        # SORT VALUES
        if group.size > 1:
            # np.sort in numba doesn't support `np.sort`'s arguments so `stable`
            # keyword can't be supported.
            # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
            data[group] = np.sort(data[group])
            if descending:
                data[group] = data[group][::-1]

        # SORT INDICES
        fill_value_count = sort_axis_len - group.size
        indices = np.arange(group.size)
        # Materialize the group's sorted values once; indexing `data[group]`
        # inside the loop would build a fresh copy on every iteration.
        group_data = data[group]
        # find a place where fill_value would be: every stored value from that
        # position onwards is shifted past the block of implicit fill values
        for pos in range(group.size):
            value = group_data[pos]
            if (not descending and fill_value < value) or (descending and fill_value > value):
                indices[pos:] += fill_value_count
                break
        result_indices[offset : offset + len(indices)] = indices
        offset += len(indices)

    sort_coords[:] = result_indices
1267
1409
1268
1410
1269
1411
@numba .jit (nopython = True , nogil = True )
@@ -1323,14 +1465,7 @@ def _arg_minmax_common(
1323
1465
assert mode in ("max" , "min" )
1324
1466
max_mode_flag = mode == "max"
1325
1467
1326
- from .core import COO
1327
-
1328
- if isinstance (x , scipy .sparse .spmatrix ):
1329
- x = COO .from_scipy_sparse (x )
1330
- elif not isinstance (x , SparseArray ):
1331
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1332
- elif not isinstance (x , COO ):
1333
- x = x .asformat (COO )
1468
+ x = _validate_coo_input (x )
1334
1469
1335
1470
if not isinstance (axis , (int , type (None ))):
1336
1471
raise ValueError (f"`axis` must be `int` or `None`, but it's: { type (axis )} ." )
0 commit comments