@@ -1316,7 +1316,7 @@ def _assert_valid_partition(
1316
1316
axis : int | None = - 1 ,
1317
1317
):
1318
1318
"""
1319
- x : input array
1319
+ x_np : input array
1320
1320
k : int
1321
1321
y : output array returned by the partition function to test
1322
1322
"""
@@ -1397,11 +1397,31 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
1397
1397
indices = argpartition (arr , k , axis = axis )
1398
1398
if axis is None :
1399
1399
arr = xp .reshape (arr , shape = (- 1 ,))
1400
+ return arr [indices ]
1401
+ if arr .ndim == 1 :
1402
+ return arr [indices ]
1403
+ return cls ._take_along_axis (arr , indices , axis = axis , xp = xp )
1404
+
1405
+ @classmethod
1406
+ def _take_along_axis (cls , arr : Array , indices : Array , axis : int , xp : ModuleType ):
1407
+ if hasattr (xp , "take_along_axis" ):
1408
+ return xp .take_along_axis (arr , indices , axis = axis )
1400
1409
if arr .ndim == 1 :
1401
1410
return arr [indices ]
1402
- if not hasattr (xp , "take_along_axis" ):
1403
- pytest .skip ("TODO: find an alternative to take_along_axis" )
1404
- return xp .take_along_axis (arr , indices , axis = axis )
1411
+ if axis == 0 :
1412
+ assert isinstance (arr .shape [1 ], int )
1413
+ arrs = []
1414
+ for i in range (arr .shape [1 ]):
1415
+ arrs .append (cls ._take_along_axis (arr [:, i , ...], indices [:, i , ...],
1416
+ axis = 0 , xp = xp ))
1417
+ return xp .stack (arrs , axis = 1 )
1418
+ axis = axis - 1 if axis != - 1 else - 1
1419
+ assert isinstance (arr .shape [0 ], int )
1420
+ arrs = []
1421
+ for i in range (arr .shape [0 ]):
1422
+ arrs .append (cls ._take_along_axis (arr [i , ...], indices [i , ...],
1423
+ axis = axis , xp = xp ))
1424
+ return xp .stack (arrs , axis = 0 )
1405
1425
1406
1426
@override
1407
1427
def test_1d (self , xp : ModuleType ):
0 commit comments