Skip to content

Commit c2827da

Browse files
committed
dask support in argpart tests
1 parent 579b3bc commit c2827da

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

tests/test_funcs.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,7 +1316,7 @@ def _assert_valid_partition(
13161316
axis: int | None = -1,
13171317
):
13181318
"""
1319-
x : input array
1319+
x_np : input array
13201320
k : int
13211321
y : output array returned by the partition function to test
13221322
"""
@@ -1397,11 +1397,31 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
13971397
indices = argpartition(arr, k, axis=axis)
13981398
if axis is None:
13991399
arr = xp.reshape(arr, shape=(-1,))
1400+
return arr[indices]
1401+
if arr.ndim == 1:
1402+
return arr[indices]
1403+
return cls._take_along_axis(arr, indices, axis=axis, xp=xp)
1404+
1405+
@classmethod
1406+
def _take_along_axis(cls, arr: Array, indices: Array, axis: int, xp: ModuleType):
1407+
if hasattr(xp, "take_along_axis"):
1408+
return xp.take_along_axis(arr, indices, axis=axis)
14001409
if arr.ndim == 1:
14011410
return arr[indices]
1402-
if not hasattr(xp, "take_along_axis"):
1403-
pytest.skip("TODO: find an alternative to take_along_axis")
1404-
return xp.take_along_axis(arr, indices, axis=axis)
1411+
if axis == 0:
1412+
assert isinstance(arr.shape[1], int)
1413+
arrs = []
1414+
for i in range(arr.shape[1]):
1415+
arrs.append(cls._take_along_axis(arr[:, i, ...], indices[:, i, ...],
1416+
axis=0, xp=xp))
1417+
return xp.stack(arrs, axis=1)
1418+
axis = axis - 1 if axis != -1 else -1
1419+
assert isinstance(arr.shape[0], int)
1420+
arrs = []
1421+
for i in range(arr.shape[0]):
1422+
arrs.append(cls._take_along_axis(arr[i, ...], indices[i, ...],
1423+
axis=axis, xp=xp))
1424+
return xp.stack(arrs, axis=0)
14051425

14061426
@override
14071427
def test_1d(self, xp: ModuleType):

0 commit comments

Comments
 (0)