@@ -1090,14 +1090,8 @@ def expand_dims(x, /, *, axis=0):
1090
1090
(1, 6, 1)
1091
1091
1092
1092
"""
1093
- from .core import COO
1094
1093
1095
- if isinstance (x , scipy .sparse .spmatrix ):
1096
- x = COO .from_scipy_sparse (x )
1097
- elif not isinstance (x , SparseArray ):
1098
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1099
- elif not isinstance (x , COO ):
1100
- x = x .asformat (COO )
1094
+ x = _validate_coo_input (x )
1101
1095
1102
1096
if not isinstance (axis , int ):
1103
1097
raise IndexError (f"Invalid axis position: { axis } " )
@@ -1109,6 +1103,8 @@ def expand_dims(x, /, *, axis=0):
1109
1103
new_shape .insert (axis , 1 )
1110
1104
new_shape = tuple (new_shape )
1111
1105
1106
+ from .core import COO
1107
+
1112
1108
return COO (
1113
1109
new_coords ,
1114
1110
x .data ,
@@ -1140,14 +1136,8 @@ def flip(x, /, *, axis=None):
1140
1136
relative to ``x``, are reordered.
1141
1137
1142
1138
"""
1143
- from .core import COO
1144
1139
1145
- if isinstance (x , scipy .sparse .spmatrix ):
1146
- x = COO .from_scipy_sparse (x )
1147
- elif not isinstance (x , SparseArray ):
1148
- raise ValueError (f"Input must be an instance of SparseArray, but it's { type (x )} ." )
1149
- elif not isinstance (x , COO ):
1150
- x = x .asformat (COO )
1140
+ x = _validate_coo_input (x )
1151
1141
1152
1142
if axis is None :
1153
1143
axis = range (x .ndim )
@@ -1158,6 +1148,8 @@ def flip(x, /, *, axis=None):
1158
1148
for ax in axis :
1159
1149
new_coords [ax , :] = x .shape [ax ] - 1 - x .coords [ax , :]
1160
1150
1151
+ from .core import COO
1152
+
1161
1153
return COO (
1162
1154
new_coords ,
1163
1155
x .data ,
@@ -1291,6 +1283,7 @@ def sort(x, /, *, axis=-1, descending=False):
1291
1283
1292
1284
"""
1293
1285
1286
+ from .core import COO
1294
1287
from .._common import moveaxis
1295
1288
1296
1289
x = _validate_coo_input (x )
@@ -1302,9 +1295,13 @@ def sort(x, /, *, axis=-1, descending=False):
1302
1295
1303
1296
x = moveaxis (x , source = axis , destination = - 1 )
1304
1297
x_shape = x .shape
1305
- x = x .reshape ((np . prod ( x_shape [: - 1 ]) , x_shape [- 1 ]))
1298
+ x = x .reshape ((- 1 , x_shape [- 1 ]))
1306
1299
1307
- _sort_coo (x .coords , x .data , x .fill_value , sort_axis_len = x_shape [- 1 ], descending = descending )
1300
+ new_coords , new_data = _sort_coo (
1301
+ x .coords , x .data , x .fill_value , sort_axis_len = x_shape [- 1 ], descending = descending
1302
+ )
1303
+
1304
+ x = COO (new_coords , new_data , x .shape , has_duplicates = False , sorted = True , fill_value = x .fill_value )
1308
1305
1309
1306
x = x .reshape (x_shape [:- 1 ] + (x_shape [- 1 ],))
1310
1307
x = moveaxis (x , source = - 1 , destination = axis )
@@ -1370,42 +1367,55 @@ def _sort_coo(
1370
1367
fill_value : float ,
1371
1368
sort_axis_len : int ,
1372
1369
descending : bool ,
1373
- ) -> None :
1370
+ ) -> Tuple [ np . ndarray , np . ndarray ] :
1374
1371
assert coords .shape [0 ] == 2
1375
1372
group_coords = coords [0 , :]
1376
1373
sort_coords = coords [1 , :]
1377
1374
1375
+ data = data .copy ()
1378
1376
result_indices = np .empty_like (sort_coords )
1379
- offset = 0 # tracks where the current group starts
1380
-
1381
- # iterate through all groups and sort each one of them
1382
- for unique_val in np .unique (group_coords ):
1383
- # .copy() required by numba, as `reshape` expects a continous array
1384
- group = np .argwhere (group_coords == unique_val ).copy ()
1385
- group = np .reshape (group , - 1 )
1386
- group = np .atleast_1d (group )
1387
-
1388
- # SORT VALUES
1389
- if group .size > 1 :
1390
- # np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1391
- # keyword can't be supported.
1392
- # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1393
- data [group ] = np .sort (data [group ])
1394
- if descending :
1395
- data [group ] = data [group ][::- 1 ]
1396
-
1397
- # SORT INDICES
1398
- fill_value_count = sort_axis_len - group .size
1399
- indices = np .arange (group .size )
1400
- # find a place where fill_value would be
1401
- for pos in range (group .size ):
1402
- if (not descending and fill_value < data [group ][pos ]) or (descending and fill_value > data [group ][pos ]):
1403
- indices [pos :] += fill_value_count
1404
- break
1405
- result_indices [offset : offset + len (indices )] = indices
1406
- offset += len (indices )
1407
-
1408
- sort_coords [:] = result_indices
1377
+
1378
+ # We iterate through all groups and sort each one of them.
1379
+ # first and last index of a group is tracked.
1380
+ prev_group = - 1
1381
+ group_first_idx = - 1
1382
+ group_last_idx = - 1
1383
+ # We add `-1` sentinel to know when the last group ends
1384
+ for idx , group in enumerate (np .append (group_coords , - 1 )):
1385
+ if group == prev_group :
1386
+ continue
1387
+
1388
+ if prev_group != - 1 :
1389
+ group_last_idx = idx
1390
+
1391
+ group_slice = slice (group_first_idx , group_last_idx )
1392
+ group_size = group_last_idx - group_first_idx
1393
+
1394
+ # SORT VALUES
1395
+ if group_size > 1 :
1396
+ # np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1397
+ # keyword can't be supported.
1398
+ # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1399
+ data [group_slice ] = np .sort (data [group_slice ])
1400
+ if descending :
1401
+ data [group_slice ] = data [group_slice ][::- 1 ]
1402
+
1403
+ # SORT INDICES
1404
+ fill_value_count = sort_axis_len - group_size
1405
+ indices = np .arange (group_size )
1406
+ # find a place where fill_value would be
1407
+ for pos in range (group_size ):
1408
+ if (not descending and fill_value < data [group_slice ][pos ]) or (
1409
+ descending and fill_value > data [group_slice ][pos ]
1410
+ ):
1411
+ indices [pos :] += fill_value_count
1412
+ break
1413
+ result_indices [group_first_idx :group_last_idx ] = indices
1414
+
1415
+ prev_group = group
1416
+ group_first_idx = idx
1417
+
1418
+ return np .vstack ((group_coords , result_indices )), data
1409
1419
1410
1420
1411
1421
@numba .jit (nopython = True , nogil = True )
0 commit comments