Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ def argpartition(
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if is_pydata_sparse_namespace(xp):
msg = "Not implemented for sparse backend"
raise NotImplementedError(msg)
if a.ndim < 1:
msg = "`a` must be at least 1-dimensional"
raise TypeError(msg)
Expand Down
35 changes: 26 additions & 9 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@
from array_api_extra._lib._utils._compat import (
device as get_device,
)
from array_api_extra._lib._utils._compat import (
is_pydata_sparse_namespace,
)
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -1344,7 +1341,7 @@ def _assert_valid_partition(
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
return partition(xp.asarray(x), k, axis=axis)

def test_1d(self, xp: ModuleType):
def _test_1d(self, xp: ModuleType):
rng = np.random.default_rng()
for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
k = int(rng.integers(n))
Expand All @@ -1355,8 +1352,7 @@ def test_1d(self, xp: ModuleType):
y = self._partition(x2, k, xp)
self._assert_valid_partition(x2, k, y, xp)

@pytest.mark.parametrize("ndim", [2, 3, 4])
def test_nd(self, xp: ModuleType, ndim: int):
def _test_nd(self, xp: ModuleType, ndim: int):
rng = np.random.default_rng()

for n in [2, 3, 5, 10, 20, 100]:
Expand All @@ -1375,20 +1371,28 @@ def test_nd(self, xp: ModuleType, ndim: int):
y = self._partition(z, k, xp, axis=None)
self._assert_valid_partition(z, k, y, xp, axis=None)

def test_input_validation(self, xp: ModuleType):
def _test_input_validation(self, xp: ModuleType):
with pytest.raises(TypeError):
_ = self._partition(np.asarray(1), 1, xp)
with pytest.raises(ValueError, match="out of bounds"):
_ = self._partition(np.asarray([1, 2]), 3, xp)

def test_1d(self, xp: ModuleType):
self._test_1d(xp)

@pytest.mark.parametrize("ndim", [2, 3, 4])
def test_nd(self, xp: ModuleType, ndim: int):
self._test_nd(xp, ndim)

def test_input_validation(self, xp: ModuleType):
self._test_input_validation(xp)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
class TestArgpartition(TestPartition):
@classmethod
@override
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
if is_pydata_sparse_namespace(xp):
pytest.xfail(reason="Sparse backend has no argsort")
arr = xp.asarray(x)
indices = argpartition(arr, k, axis=axis)
if axis is None:
Expand All @@ -1398,3 +1402,16 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
if not hasattr(xp, "take_along_axis"):
pytest.skip("TODO: find an alternative to take_along_axis")
return xp.take_along_axis(arr, indices, axis=axis)

@override
def test_1d(self, xp: ModuleType):
self._test_1d(xp)

@pytest.mark.parametrize("ndim", [2, 3, 4])
@override
def test_nd(self, xp: ModuleType, ndim: int):
self._test_nd(xp, ndim)

@override
def test_input_validation(self, xp: ModuleType):
self._test_input_validation(xp)