@@ -1310,41 +1310,52 @@ def test_xp(self, xp: ModuleType):
1310
1310
1311
1311
class TestPartition :
1312
1312
@classmethod
1313
- def _assert_valid_partition (cls , x : Array , k : int , xp : ModuleType , axis : int = - 1 ):
1314
- if x .ndim != 1 and axis == 0 :
1315
- assert isinstance (x .shape [1 ], int )
1316
- for i in range (x .shape [1 ]):
1317
- cls ._assert_valid_partition (x [:, i , ...], k , xp , axis = 0 )
1318
- elif x .ndim != 1 :
1313
+ def _assert_valid_partition (
1314
+ cls ,
1315
+ x_np : np .ndarray | None ,
1316
+ k : int ,
1317
+ y : Array ,
1318
+ xp : ModuleType ,
1319
+ axis : int | None = - 1 ,
1320
+ ):
1321
+ """
1322
+ x : input array
1323
+ k : int
1324
+ y : output array returned by the partition function to test
1325
+ """
1326
+ if x_np is not None :
1327
+ assert y .shape == np .partition (x_np , k , axis = axis ).shape
1328
+ if y .ndim != 1 and axis == 0 :
1329
+ assert isinstance (y .shape [1 ], int )
1330
+ for i in range (y .shape [1 ]):
1331
+ cls ._assert_valid_partition (None , k , y [:, i , ...], xp , axis = 0 )
1332
+ elif y .ndim != 1 :
1333
+ assert axis is not None
1319
1334
axis = axis - 1 if axis != - 1 else - 1
1320
- assert isinstance (x .shape [0 ], int )
1321
- for i in range (x .shape [0 ]):
1322
- cls ._assert_valid_partition (x [i , ...], k , xp , axis = axis )
1335
+ assert isinstance (y .shape [0 ], int )
1336
+ for i in range (y .shape [0 ]):
1337
+ cls ._assert_valid_partition (None , k , y [i , ...], xp , axis = axis )
1323
1338
else :
1324
1339
if k > 0 :
1325
- assert xp .max (x [:k ]) <= x [k ]
1326
- assert x [k ] <= xp .min (x [k :])
1340
+ assert xp .max (y [:k ]) <= y [k ]
1341
+ assert y [k ] <= xp .min (y [k :])
1327
1342
1328
1343
@classmethod
1329
- def _partition (
1330
- cls ,
1331
- x : Array ,
1332
- k : int ,
1333
- xp : ModuleType , # noqa: ARG003
1334
- axis : int | None = - 1 ,
1335
- ):
1336
- return partition (x , k , axis = axis )
1344
+ def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1345
+ return partition (xp .asarray (x ), k , axis = axis )
1337
1346
1338
1347
def test_1d (self , xp : ModuleType ):
1339
1348
rng = np .random .default_rng ()
1340
1349
for n in [2 , 3 , 4 , 5 , 7 , 10 , 20 , 50 , 100 , 1_000 ]:
1341
1350
k = int (rng .integers (n ))
1342
- x = xp .asarray (rng .integers (n , size = n ))
1343
- self ._assert_valid_partition (self ._partition (x , k , xp ), k , xp )
1344
- x = xp .asarray (rng .random (n ))
1345
- self ._assert_valid_partition (self ._partition (x , k , xp ), k , xp )
1346
-
1347
- @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 , 5 ])
1351
+ x1 = rng .integers (n , size = n )
1352
+ y = self ._partition (x1 , k , xp )
1353
+ self ._assert_valid_partition (x1 , k , y , xp )
1354
+ x2 = rng .random (n )
1355
+ y = self ._partition (x2 , k , xp )
1356
+ self ._assert_valid_partition (x2 , k , y , xp )
1357
+
1358
+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1348
1359
def test_nd (self , xp : ModuleType , ndim : int ):
1349
1360
rng = np .random .default_rng ()
1350
1361
@@ -1355,27 +1366,35 @@ def test_nd(self, xp: ModuleType, ndim: int):
1355
1366
for i in range (ndim ):
1356
1367
shape = base_shape [:]
1357
1368
shape [i ] = n
1358
- x = xp . asarray ( rng .integers (n , size = tuple (shape ) ))
1369
+ x = rng .integers (n , size = tuple (shape ))
1359
1370
y = self ._partition (x , k , xp , axis = i )
1360
- self ._assert_valid_partition (y , k , xp , axis = i )
1371
+ self ._assert_valid_partition (x , k , y , xp , axis = i )
1372
+
1373
+ z = rng .random (tuple (base_shape ))
1374
+ k = int (rng .integers (z .size ))
1375
+ y = self ._partition (z , k , xp , axis = None )
1376
+ self ._assert_valid_partition (z , k , y , xp , axis = None )
1361
1377
1362
1378
def test_input_validation (self , xp : ModuleType ):
1363
1379
with pytest .raises (TypeError ):
1364
- _ = self ._partition (xp .asarray (1 ), 1 , xp )
1380
+ _ = self ._partition (np .asarray (1 ), 1 , xp )
1365
1381
with pytest .raises (ValueError , match = "out of bounds" ):
1366
- _ = self ._partition (xp .asarray ([1 , 2 ]), 3 , xp )
1382
+ _ = self ._partition (np .asarray ([1 , 2 ]), 3 , xp )
1367
1383
1368
1384
1369
1385
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
1370
1386
class TestArgpartition (TestPartition ):
1371
1387
@classmethod
1372
1388
@override
1373
- def _partition (cls , x : Array , k : int , xp : ModuleType , axis : int | None = - 1 ):
1389
+ def _partition (cls , x : np . ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1374
1390
if is_pydata_sparse_namespace (xp ):
1375
1391
pytest .xfail (reason = "Sparse backend has no argsort" )
1376
- indices = argpartition (x , k , axis = axis )
1377
- if x .ndim == 1 :
1378
- return x [indices ]
1392
+ arr = xp .asarray (x )
1393
+ indices = argpartition (arr , k , axis = axis )
1394
+ if axis is None :
1395
+ arr = xp .reshape (arr , shape = (- 1 ,))
1396
+ if arr .ndim == 1 :
1397
+ return arr [indices ]
1379
1398
if not hasattr (xp , "take_along_axis" ):
1380
1399
pytest .skip ("TODO: find an alternative to take_along_axis" )
1381
- return xp .take_along_axis (x , indices , axis = axis )
1400
+ return xp .take_along_axis (arr , indices , axis = axis )
0 commit comments