Skip to content

Commit 6e69083

Browse files
committed
fix xfail thingy
1 parent 6efc73a commit 6e69083

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

src/array_api_extra/_delegation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ def argpartition(
437437
# Validate inputs.
438438
if xp is None:
439439
xp = array_namespace(a)
440+
if is_pydata_sparse_namespace(xp):
441+
msg = "Not implemented for sparse backend"
442+
raise NotImplementedError(msg)
440443
if a.ndim < 1:
441444
msg = "`a` must be at least 1-dimensional"
442445
raise TypeError(msg)

tests/test_funcs.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@
3636
from array_api_extra._lib._utils._compat import (
3737
device as get_device,
3838
)
39-
from array_api_extra._lib._utils._compat import (
40-
is_pydata_sparse_namespace,
41-
)
4239
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
4340
from array_api_extra._lib._utils._typing import Array, Device
4441
from array_api_extra.testing import lazy_xp_function
@@ -1344,7 +1341,7 @@ def _assert_valid_partition(
13441341
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
13451342
return partition(xp.asarray(x), k, axis=axis)
13461343

1347-
def test_1d(self, xp: ModuleType):
1344+
def _test_1d(self, xp: ModuleType):
13481345
rng = np.random.default_rng()
13491346
for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
13501347
k = int(rng.integers(n))
@@ -1355,8 +1352,7 @@ def test_1d(self, xp: ModuleType):
13551352
y = self._partition(x2, k, xp)
13561353
self._assert_valid_partition(x2, k, y, xp)
13571354

1358-
@pytest.mark.parametrize("ndim", [2, 3, 4])
1359-
def test_nd(self, xp: ModuleType, ndim: int):
1355+
def _test_nd(self, xp: ModuleType, ndim: int):
13601356
rng = np.random.default_rng()
13611357

13621358
for n in [2, 3, 5, 10, 20, 100]:
@@ -1375,20 +1371,28 @@ def test_nd(self, xp: ModuleType, ndim: int):
13751371
y = self._partition(z, k, xp, axis=None)
13761372
self._assert_valid_partition(z, k, y, xp, axis=None)
13771373

1378-
def test_input_validation(self, xp: ModuleType):
1374+
def _test_input_validation(self, xp: ModuleType):
13791375
with pytest.raises(TypeError):
13801376
_ = self._partition(np.asarray(1), 1, xp)
13811377
with pytest.raises(ValueError, match="out of bounds"):
13821378
_ = self._partition(np.asarray([1, 2]), 3, xp)
13831379

1380+
def test_1d(self, xp: ModuleType):
1381+
self._test_1d(xp)
1382+
1383+
@pytest.mark.parametrize("ndim", [2, 3, 4])
1384+
def test_nd(self, xp: ModuleType, ndim: int):
1385+
self._test_nd(xp, ndim)
1386+
1387+
def test_input_validation(self, xp: ModuleType):
1388+
self._test_input_validation(xp)
1389+
13841390

13851391
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
13861392
class TestArgpartition(TestPartition):
13871393
@classmethod
13881394
@override
13891395
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
1390-
if is_pydata_sparse_namespace(xp):
1391-
pytest.xfail(reason="Sparse backend has no argsort")
13921396
arr = xp.asarray(x)
13931397
indices = argpartition(arr, k, axis=axis)
13941398
if axis is None:
@@ -1398,3 +1402,16 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
13981402
if not hasattr(xp, "take_along_axis"):
13991403
pytest.skip("TODO: find an alternative to take_along_axis")
14001404
return xp.take_along_axis(arr, indices, axis=axis)
1405+
1406+
@override
1407+
def test_1d(self, xp: ModuleType):
1408+
self._test_1d(xp)
1409+
1410+
@pytest.mark.parametrize("ndim", [2, 3, 4])
1411+
@override
1412+
def test_nd(self, xp: ModuleType, ndim: int):
1413+
self._test_nd(xp, ndim)
1414+
1415+
@override
1416+
def test_input_validation(self, xp: ModuleType):
1417+
self._test_input_validation(xp)

0 commit comments

Comments
 (0)