36
36
from array_api_extra ._lib ._utils ._compat import (
37
37
device as get_device ,
38
38
)
39
- from array_api_extra ._lib ._utils ._compat import (
40
- is_pydata_sparse_namespace ,
41
- )
42
39
from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
43
40
from array_api_extra ._lib ._utils ._typing import Array , Device
44
41
from array_api_extra .testing import lazy_xp_function
@@ -1344,7 +1341,7 @@ def _assert_valid_partition(
1344
1341
def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1345
1342
return partition (xp .asarray (x ), k , axis = axis )
1346
1343
1347
- def test_1d (self , xp : ModuleType ):
1344
+ def _test_1d (self , xp : ModuleType ):
1348
1345
rng = np .random .default_rng ()
1349
1346
for n in [2 , 3 , 4 , 5 , 7 , 10 , 20 , 50 , 100 , 1_000 ]:
1350
1347
k = int (rng .integers (n ))
@@ -1355,8 +1352,7 @@ def test_1d(self, xp: ModuleType):
1355
1352
y = self ._partition (x2 , k , xp )
1356
1353
self ._assert_valid_partition (x2 , k , y , xp )
1357
1354
1358
- @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1359
- def test_nd (self , xp : ModuleType , ndim : int ):
1355
+ def _test_nd (self , xp : ModuleType , ndim : int ):
1360
1356
rng = np .random .default_rng ()
1361
1357
1362
1358
for n in [2 , 3 , 5 , 10 , 20 , 100 ]:
@@ -1375,20 +1371,28 @@ def test_nd(self, xp: ModuleType, ndim: int):
1375
1371
y = self ._partition (z , k , xp , axis = None )
1376
1372
self ._assert_valid_partition (z , k , y , xp , axis = None )
1377
1373
1378
- def test_input_validation (self , xp : ModuleType ):
1374
+ def _test_input_validation (self , xp : ModuleType ):
1379
1375
with pytest .raises (TypeError ):
1380
1376
_ = self ._partition (np .asarray (1 ), 1 , xp )
1381
1377
with pytest .raises (ValueError , match = "out of bounds" ):
1382
1378
_ = self ._partition (np .asarray ([1 , 2 ]), 3 , xp )
1383
1379
1380
+ def test_1d (self , xp : ModuleType ):
1381
+ self ._test_1d (xp )
1382
+
1383
+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1384
+ def test_nd (self , xp : ModuleType , ndim : int ):
1385
+ self ._test_nd (xp , ndim )
1386
+
1387
+ def test_input_validation (self , xp : ModuleType ):
1388
+ self ._test_input_validation (xp )
1389
+
1384
1390
1385
1391
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
1386
1392
class TestArgpartition (TestPartition ):
1387
1393
@classmethod
1388
1394
@override
1389
1395
def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1390
- if is_pydata_sparse_namespace (xp ):
1391
- pytest .xfail (reason = "Sparse backend has no argsort" )
1392
1396
arr = xp .asarray (x )
1393
1397
indices = argpartition (arr , k , axis = axis )
1394
1398
if axis is None :
@@ -1398,3 +1402,16 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
1398
1402
if not hasattr (xp , "take_along_axis" ):
1399
1403
pytest .skip ("TODO: find an alternative to take_along_axis" )
1400
1404
return xp .take_along_axis (arr , indices , axis = axis )
1405
+
1406
+ @override
1407
+ def test_1d (self , xp : ModuleType ):
1408
+ self ._test_1d (xp )
1409
+
1410
+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1411
+ @override
1412
+ def test_nd (self , xp : ModuleType , ndim : int ):
1413
+ self ._test_nd (xp , ndim )
1414
+
1415
+ @override
1416
+ def test_input_validation (self , xp : ModuleType ):
1417
+ self ._test_input_validation (xp )
0 commit comments