@@ -1310,41 +1310,52 @@ def test_xp(self, xp: ModuleType):
13101310
13111311class TestPartition :
13121312 @classmethod
1313- def _assert_valid_partition (cls , x : Array , k : int , xp : ModuleType , axis : int = - 1 ):
1314- if x .ndim != 1 and axis == 0 :
1315- assert isinstance (x .shape [1 ], int )
1316- for i in range (x .shape [1 ]):
1317- cls ._assert_valid_partition (x [:, i , ...], k , xp , axis = 0 )
1318- elif x .ndim != 1 :
1313+ def _assert_valid_partition (
1314+ cls ,
1315+ x_np : np .ndarray | None ,
1316+ k : int ,
1317+ y : Array ,
1318+ xp : ModuleType ,
1319+ axis : int | None = - 1 ,
1320+ ):
1321+ """
1322+ x : input array
1323+ k : int
1324+ y : output array returned by the partition function to test
1325+ """
1326+ if x_np is not None :
1327+ assert y .shape == np .partition (x_np , k , axis = axis ).shape
1328+ if y .ndim != 1 and axis == 0 :
1329+ assert isinstance (y .shape [1 ], int )
1330+ for i in range (y .shape [1 ]):
1331+ cls ._assert_valid_partition (None , k , y [:, i , ...], xp , axis = 0 )
1332+ elif y .ndim != 1 :
1333+ assert axis is not None
13191334 axis = axis - 1 if axis != - 1 else - 1
1320- assert isinstance (x .shape [0 ], int )
1321- for i in range (x .shape [0 ]):
1322- cls ._assert_valid_partition (x [i , ...], k , xp , axis = axis )
1335+ assert isinstance (y .shape [0 ], int )
1336+ for i in range (y .shape [0 ]):
1337+ cls ._assert_valid_partition (None , k , y [i , ...], xp , axis = axis )
13231338 else :
13241339 if k > 0 :
1325- assert xp .max (x [:k ]) <= x [k ]
1326- assert x [k ] <= xp .min (x [k :])
1340+ assert xp .max (y [:k ]) <= y [k ]
1341+ assert y [k ] <= xp .min (y [k :])
13271342
13281343 @classmethod
1329- def _partition (
1330- cls ,
1331- x : Array ,
1332- k : int ,
1333- xp : ModuleType , # noqa: ARG003
1334- axis : int | None = - 1 ,
1335- ):
1336- return partition (x , k , axis = axis )
1344+ def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1345+ return partition (xp .asarray (x ), k , axis = axis )
13371346
13381347 def test_1d (self , xp : ModuleType ):
13391348 rng = np .random .default_rng ()
13401349 for n in [2 , 3 , 4 , 5 , 7 , 10 , 20 , 50 , 100 , 1_000 ]:
13411350 k = int (rng .integers (n ))
1342- x = xp .asarray (rng .integers (n , size = n ))
1343- self ._assert_valid_partition (self ._partition (x , k , xp ), k , xp )
1344- x = xp .asarray (rng .random (n ))
1345- self ._assert_valid_partition (self ._partition (x , k , xp ), k , xp )
1346-
1347- @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 , 5 ])
1351+ x1 = rng .integers (n , size = n )
1352+ y = self ._partition (x1 , k , xp )
1353+ self ._assert_valid_partition (x1 , k , y , xp )
1354+ x2 = rng .random (n )
1355+ y = self ._partition (x2 , k , xp )
1356+ self ._assert_valid_partition (x2 , k , y , xp )
1357+
1358+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
13481359 def test_nd (self , xp : ModuleType , ndim : int ):
13491360 rng = np .random .default_rng ()
13501361
@@ -1355,27 +1366,35 @@ def test_nd(self, xp: ModuleType, ndim: int):
13551366 for i in range (ndim ):
13561367 shape = base_shape [:]
13571368 shape [i ] = n
1358- x = xp . asarray ( rng .integers (n , size = tuple (shape ) ))
1369+ x = rng .integers (n , size = tuple (shape ))
13591370 y = self ._partition (x , k , xp , axis = i )
1360- self ._assert_valid_partition (y , k , xp , axis = i )
1371+ self ._assert_valid_partition (x , k , y , xp , axis = i )
1372+
1373+ z = rng .random (tuple (base_shape ))
1374+ k = int (rng .integers (z .size ))
1375+ y = self ._partition (z , k , xp , axis = None )
1376+ self ._assert_valid_partition (z , k , y , xp , axis = None )
13611377
13621378 def test_input_validation (self , xp : ModuleType ):
13631379 with pytest .raises (TypeError ):
1364- _ = self ._partition (xp .asarray (1 ), 1 , xp )
1380+ _ = self ._partition (np .asarray (1 ), 1 , xp )
13651381 with pytest .raises (ValueError , match = "out of bounds" ):
1366- _ = self ._partition (xp .asarray ([1 , 2 ]), 3 , xp )
1382+ _ = self ._partition (np .asarray ([1 , 2 ]), 3 , xp )
13671383
13681384
13691385@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
13701386class TestArgpartition (TestPartition ):
13711387 @classmethod
13721388 @override
1373- def _partition (cls , x : Array , k : int , xp : ModuleType , axis : int | None = - 1 ):
1389+ def _partition (cls , x : np . ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
13741390 if is_pydata_sparse_namespace (xp ):
13751391 pytest .xfail (reason = "Sparse backend has no argsort" )
1376- indices = argpartition (x , k , axis = axis )
1377- if x .ndim == 1 :
1378- return x [indices ]
1392+ arr = xp .asarray (x )
1393+ indices = argpartition (arr , k , axis = axis )
1394+ if axis is None :
1395+ arr = xp .reshape (arr , shape = (- 1 ,))
1396+ if arr .ndim == 1 :
1397+ return arr [indices ]
13791398 if not hasattr (xp , "take_along_axis" ):
13801399 pytest .skip ("TODO: find an alternative to take_along_axis" )
1381- return xp .take_along_axis (x , indices , axis = axis )
1400+ return xp .take_along_axis (arr , indices , axis = axis )
0 commit comments