2
2
import warnings
3
3
from collections .abc import Iterable
4
4
from functools import reduce
5
- from typing import NamedTuple , Optional , Tuple
5
+ from typing import Any , NamedTuple , Optional , Tuple
6
6
7
7
import numba
8
8
@@ -1090,14 +1090,8 @@ def expand_dims(x, /, *, axis=0):
1090
1090
(1, 6, 1)
1091
1091
1092
1092
"""
1093
- from .core import COO
1094
1093
1095
- if isinstance (x , scipy .sparse .spmatrix ):
1096
- x = COO .from_scipy_sparse (x )
1097
- elif not isinstance (x , SparseArray ):
1098
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1099
- elif not isinstance (x , COO ):
1100
- x = x .asformat (COO )
1094
+ x = _validate_coo_input (x )
1101
1095
1102
1096
if not isinstance (axis , int ):
1103
1097
raise IndexError (f"Invalid axis position: { axis } " )
@@ -1109,6 +1103,8 @@ def expand_dims(x, /, *, axis=0):
1109
1103
new_shape .insert (axis , 1 )
1110
1104
new_shape = tuple (new_shape )
1111
1105
1106
+ from .core import COO
1107
+
1112
1108
return COO (
1113
1109
new_coords ,
1114
1110
x .data ,
@@ -1140,14 +1136,8 @@ def flip(x, /, *, axis=None):
1140
1136
relative to ``x``, are reordered.
1141
1137
1142
1138
"""
1143
- from .core import COO
1144
1139
1145
- if isinstance (x , scipy .sparse .spmatrix ):
1146
- x = COO .from_scipy_sparse (x )
1147
- elif not isinstance (x , SparseArray ):
1148
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1149
- elif not isinstance (x , COO ):
1150
- x = x .asformat (COO )
1140
+ x = _validate_coo_input (x )
1151
1141
1152
1142
if axis is None :
1153
1143
axis = range (x .ndim )
@@ -1158,6 +1148,8 @@ def flip(x, /, *, axis=None):
1158
1148
for ax in axis :
1159
1149
new_coords [ax , :] = x .shape [ax ] - 1 - x .coords [ax , :]
1160
1150
1151
+ from .core import COO
1152
+
1161
1153
return COO (
1162
1154
new_coords ,
1163
1155
x .data ,
@@ -1203,14 +1195,8 @@ def unique_counts(x, /):
1203
1195
>>> sparse.unique_counts(x)
1204
1196
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
1205
1197
"""
1206
- from .core import COO
1207
1198
1208
- if isinstance (x , scipy .sparse .spmatrix ):
1209
- x = COO .from_scipy_sparse (x )
1210
- elif not isinstance (x , SparseArray ):
1211
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1212
- elif not isinstance (x , COO ):
1213
- x = x .asformat (COO )
1199
+ x = _validate_coo_input (x )
1214
1200
1215
1201
x = x .flatten ()
1216
1202
values , counts = np .unique (x .data , return_counts = True )
@@ -1250,6 +1236,116 @@ def unique_values(x, /):
1250
1236
>>> sparse.unique_values(x)
1251
1237
array([-3, 0, 1, 2])
1252
1238
"""
1239
+
1240
+ x = _validate_coo_input (x )
1241
+
1242
+ x = x .flatten ()
1243
+ values = np .unique (x .data )
1244
+ if x .nnz < x .size :
1245
+ values = np .sort (np .concatenate ([[x .fill_value ], values ]))
1246
+ return values
1247
+
1248
+
1249
def sort(x, /, *, axis=-1, descending=False):
    """
    Returns a sorted copy of an input array ``x``.

    Parameters
    ----------
    x : SparseArray
        Input array. Should have a real-valued data type.
    axis : int
        Axis along which to sort. If set to ``-1``, the function must sort along
        the last axis. For one-dimensional input the value is ignored and the
        only axis is sorted. Default: ``-1``.
    descending : bool
        Sort order. If ``True``, the array must be sorted in descending order (by value).
        If ``False``, the array must be sorted in ascending order (by value).
        Default: ``False``.

    Returns
    -------
    out : COO
        A sorted array with the same shape as ``x``.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.

    Examples
    --------
    >>> import sparse
    >>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
    >>> sparse.sort(x).todense()
    array([-3,  0,  0,  1,  2,  2])
    >>> sparse.sort(x, descending=True).todense()
    array([ 2,  2,  1,  0,  0, -3])

    """

    from .._common import moveaxis
    from .core import COO

    x = _validate_coo_input(x)

    # The sorting kernel works on a 2-D (groups, sort axis) layout, so a 1-D
    # input temporarily gets a leading axis of length one.
    was_1d = x.ndim == 1
    if was_1d:
        x = x[None, :]
        axis = -1

    # Bring the sort axis last, then collapse every other axis into one
    # "group" axis so each row can be sorted independently.
    x = moveaxis(x, source=axis, destination=-1)
    x_shape = x.shape
    x = x.reshape((-1, x_shape[-1]))

    new_coords, new_data = _sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)

    x = COO(new_coords, new_data, x.shape, has_duplicates=False, sorted=True, fill_value=x.fill_value)

    # Undo the collapse and put the sort axis back where it came from.
    x = x.reshape(x_shape)
    x = moveaxis(x, source=-1, destination=axis)

    if was_1d:
        # Remove only the leading axis added above. A bare ``squeeze()`` would
        # also drop the sort axis when it has length one, turning a shape
        # ``(1,)`` input into a 0-D result.
        x = x.reshape(x_shape[1:])

    return x
1308
+
1309
+
1310
def take(x, indices, /, *, axis=None):
    """
    Selects elements of an array at the given positions along one axis.

    Parameters
    ----------
    x : SparseArray
        Input array.
    indices : ndarray
        Array indices. The array must be one-dimensional and have an integer data type.
    axis : int
        Axis over which to select values. A negative ``axis`` is counted from
        the last dimension. When ``None``, the input is flattened first and
        ``indices`` address the flattened array. Default: ``None``.

    Returns
    -------
    out : COO
        A COO array with requested indices.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.

    """

    arr = _validate_coo_input(x)

    if axis is not None:
        ax = normalize_axis(axis, arr.ndim)
        # Full slices for every axis before ``ax``; the trailing Ellipsis
        # leaves the remaining axes untouched.
        leading = (slice(None),) * ax
        return arr[leading + (indices, ...)]

    flat = arr.flatten()
    return flat[indices]
1346
+
1347
+
1348
+ def _validate_coo_input (x : Any ):
1253
1349
from .core import COO
1254
1350
1255
1351
if isinstance (x , scipy .sparse .spmatrix ):
@@ -1259,11 +1355,65 @@ def unique_values(x, /):
1259
1355
elif not isinstance (x , COO ):
1260
1356
x = x .asformat (COO )
1261
1357
1262
- x = x .flatten ()
1263
- values = np .unique (x .data )
1264
- if x .nnz < x .size :
1265
- values = np .sort (np .concatenate ([[x .fill_value ], values ]))
1266
- return values
1358
+ return x
1359
+
1360
+
1361
@numba.jit(nopython=True, nogil=True)
def _sort_coo(
    coords: np.ndarray,
    data: np.ndarray,
    fill_value: float,
    sort_axis_len: int,
    descending: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sort the stored values of a two-dimensional COO array along its second axis.

    Row 0 of ``coords`` identifies the group (row) an element belongs to and
    row 1 its position along the sort axis. Each group is sorted on its own.
    The incoming positions along the sort axis are not used for ordering;
    new positions are computed so that the implicit ``fill_value`` entries
    occupy one contiguous run at their sorted place.

    Parameters
    ----------
    coords : np.ndarray
        Coordinates with exactly two rows (checked with an ``assert``).
    data : np.ndarray
        Stored values matching ``coords``. The input is copied, not modified.
    fill_value : float
        Value of the elements that are not stored explicitly.
    sort_axis_len : int
        Length of the sort axis; used to count the implicit fill values of a group.
    descending : bool
        Sort in descending order if ``True``, ascending otherwise.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        New coordinates (group row unchanged, recomputed sort-axis row) and
        the sorted copy of ``data``.

    Notes
    -----
    NOTE(review): group boundaries are detected by a change of value in
    ``coords[0]`` and ``-1`` is used as a sentinel, so this presumably expects
    elements of one group to be contiguous and group ids to be non-negative
    -- confirm against the caller.
    """
    assert coords.shape[0] == 2
    group_coords = coords[0, :]
    sort_coords = coords[1, :]

    # Work on a copy so the caller's data buffer stays untouched.
    data = data.copy()
    # Only the shape/dtype of `sort_coords` is reused; every entry is overwritten below.
    result_indices = np.empty_like(sort_coords)

    # We iterate through all groups and sort each one of them.
    # The first and last index of the current group are tracked.
    prev_group = -1
    group_first_idx = -1
    group_last_idx = -1
    # We add a `-1` sentinel to know when the last group ends
    for idx, group in enumerate(np.append(group_coords, -1)):
        if group == prev_group:
            continue

        # A new group id was seen, so the previous group (if any) is complete
        # and spans [group_first_idx, idx).
        if prev_group != -1:
            group_last_idx = idx

            group_slice = slice(group_first_idx, group_last_idx)
            group_size = group_last_idx - group_first_idx

            # SORT VALUES
            if group_size > 1:
                # np.sort in numba doesn't support `np.sort`'s arguments so `stable`
                # keyword can't be supported.
                # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
                data[group_slice] = np.sort(data[group_slice])
                if descending:
                    data[group_slice] = data[group_slice][::-1]

            # SORT INDICES
            # Number of implicit fill values in this group.
            fill_value_count = sort_axis_len - group_size
            indices = np.arange(group_size)
            # find a place where fill_value would be
            for pos in range(group_size):
                if (not descending and fill_value < data[group_slice][pos]) or (
                    descending and fill_value > data[group_slice][pos]
                ):
                    # Everything from `pos` onwards sorts after the fill values,
                    # so shift it past the run of implicit entries.
                    indices[pos:] += fill_value_count
                    break
            result_indices[group_first_idx:group_last_idx] = indices

        # Start tracking the group that begins at this index.
        prev_group = group
        group_first_idx = idx

    return np.vstack((group_coords, result_indices)), data
1267
1417
1268
1418
1269
1419
@numba .jit (nopython = True , nogil = True )
@@ -1323,14 +1473,7 @@ def _arg_minmax_common(
1323
1473
assert mode in ("max" , "min" )
1324
1474
max_mode_flag = mode == "max"
1325
1475
1326
- from .core import COO
1327
-
1328
- if isinstance (x , scipy .sparse .spmatrix ):
1329
- x = COO .from_scipy_sparse (x )
1330
- elif not isinstance (x , SparseArray ):
1331
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1332
- elif not isinstance (x , COO ):
1333
- x = x .asformat (COO )
1476
+ x = _validate_coo_input (x )
1334
1477
1335
1478
if not isinstance (axis , (int , type (None ))):
1336
1479
raise ValueError (f"`axis` must be `int` or `None`, but it's: { type (axis )} ." )
0 commit comments