Skip to content

Commit 6efc73a

Browse files
committed
improved tests & coverage
1 parent 74c509f commit 6efc73a

File tree

2 files changed

+55
-36
lines changed

2 files changed

+55
-36
lines changed

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def partition(
363363
msg = "`a` must be at least 1-dimensional"
364364
raise TypeError(msg)
365365
if axis is None:
366-
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
366+
return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
367367
(size,) = eager_shape(a, axis)
368368
if not (0 <= kth < size):
369369
msg = f"kth(={kth}) out of bounds [0 {size})"
@@ -441,7 +441,7 @@ def argpartition(
441441
msg = "`a` must be at least 1-dimensional"
442442
raise TypeError(msg)
443443
if axis is None:
444-
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
444+
return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
445445
(size,) = eager_shape(a, axis)
446446
if not (0 <= kth < size):
447447
msg = f"kth(={kth}) out of bounds [0 {size})"

tests/test_funcs.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,41 +1310,52 @@ def test_xp(self, xp: ModuleType):
13101310

13111311
class TestPartition:
13121312
@classmethod
1313-
def _assert_valid_partition(cls, x: Array, k: int, xp: ModuleType, axis: int = -1):
1314-
if x.ndim != 1 and axis == 0:
1315-
assert isinstance(x.shape[1], int)
1316-
for i in range(x.shape[1]):
1317-
cls._assert_valid_partition(x[:, i, ...], k, xp, axis=0)
1318-
elif x.ndim != 1:
1313+
def _assert_valid_partition(
1314+
cls,
1315+
x_np: np.ndarray | None,
1316+
k: int,
1317+
y: Array,
1318+
xp: ModuleType,
1319+
axis: int | None = -1,
1320+
):
1321+
"""
1322+
x : input array
1323+
k : int
1324+
y : output array returned by the partition function to test
1325+
"""
1326+
if x_np is not None:
1327+
assert y.shape == np.partition(x_np, k, axis=axis).shape
1328+
if y.ndim != 1 and axis == 0:
1329+
assert isinstance(y.shape[1], int)
1330+
for i in range(y.shape[1]):
1331+
cls._assert_valid_partition(None, k, y[:, i, ...], xp, axis=0)
1332+
elif y.ndim != 1:
1333+
assert axis is not None
13191334
axis = axis - 1 if axis != -1 else -1
1320-
assert isinstance(x.shape[0], int)
1321-
for i in range(x.shape[0]):
1322-
cls._assert_valid_partition(x[i, ...], k, xp, axis=axis)
1335+
assert isinstance(y.shape[0], int)
1336+
for i in range(y.shape[0]):
1337+
cls._assert_valid_partition(None, k, y[i, ...], xp, axis=axis)
13231338
else:
13241339
if k > 0:
1325-
assert xp.max(x[:k]) <= x[k]
1326-
assert x[k] <= xp.min(x[k:])
1340+
assert xp.max(y[:k]) <= y[k]
1341+
assert y[k] <= xp.min(y[k:])
13271342

13281343
@classmethod
1329-
def _partition(
1330-
cls,
1331-
x: Array,
1332-
k: int,
1333-
xp: ModuleType, # noqa: ARG003
1334-
axis: int | None = -1,
1335-
):
1336-
return partition(x, k, axis=axis)
1344+
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
1345+
return partition(xp.asarray(x), k, axis=axis)
13371346

13381347
def test_1d(self, xp: ModuleType):
13391348
rng = np.random.default_rng()
13401349
for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
13411350
k = int(rng.integers(n))
1342-
x = xp.asarray(rng.integers(n, size=n))
1343-
self._assert_valid_partition(self._partition(x, k, xp), k, xp)
1344-
x = xp.asarray(rng.random(n))
1345-
self._assert_valid_partition(self._partition(x, k, xp), k, xp)
1346-
1347-
@pytest.mark.parametrize("ndim", [2, 3, 4, 5])
1351+
x1 = rng.integers(n, size=n)
1352+
y = self._partition(x1, k, xp)
1353+
self._assert_valid_partition(x1, k, y, xp)
1354+
x2 = rng.random(n)
1355+
y = self._partition(x2, k, xp)
1356+
self._assert_valid_partition(x2, k, y, xp)
1357+
1358+
@pytest.mark.parametrize("ndim", [2, 3, 4])
13481359
def test_nd(self, xp: ModuleType, ndim: int):
13491360
rng = np.random.default_rng()
13501361

@@ -1355,27 +1366,35 @@ def test_nd(self, xp: ModuleType, ndim: int):
13551366
for i in range(ndim):
13561367
shape = base_shape[:]
13571368
shape[i] = n
1358-
x = xp.asarray(rng.integers(n, size=tuple(shape)))
1369+
x = rng.integers(n, size=tuple(shape))
13591370
y = self._partition(x, k, xp, axis=i)
1360-
self._assert_valid_partition(y, k, xp, axis=i)
1371+
self._assert_valid_partition(x, k, y, xp, axis=i)
1372+
1373+
z = rng.random(tuple(base_shape))
1374+
k = int(rng.integers(z.size))
1375+
y = self._partition(z, k, xp, axis=None)
1376+
self._assert_valid_partition(z, k, y, xp, axis=None)
13611377

13621378
def test_input_validation(self, xp: ModuleType):
13631379
with pytest.raises(TypeError):
1364-
_ = self._partition(xp.asarray(1), 1, xp)
1380+
_ = self._partition(np.asarray(1), 1, xp)
13651381
with pytest.raises(ValueError, match="out of bounds"):
1366-
_ = self._partition(xp.asarray([1, 2]), 3, xp)
1382+
_ = self._partition(np.asarray([1, 2]), 3, xp)
13671383

13681384

13691385
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
13701386
class TestArgpartition(TestPartition):
13711387
@classmethod
13721388
@override
1373-
def _partition(cls, x: Array, k: int, xp: ModuleType, axis: int | None = -1):
1389+
def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
13741390
if is_pydata_sparse_namespace(xp):
13751391
pytest.xfail(reason="Sparse backend has no argsort")
1376-
indices = argpartition(x, k, axis=axis)
1377-
if x.ndim == 1:
1378-
return x[indices]
1392+
arr = xp.asarray(x)
1393+
indices = argpartition(arr, k, axis=axis)
1394+
if axis is None:
1395+
arr = xp.reshape(arr, shape=(-1,))
1396+
if arr.ndim == 1:
1397+
return arr[indices]
13791398
if not hasattr(xp, "take_along_axis"):
13801399
pytest.skip("TODO: find an alternative to take_along_axis")
1381-
return xp.take_along_axis(x, indices, axis=axis)
1400+
return xp.take_along_axis(arr, indices, axis=axis)

0 commit comments

Comments
 (0)