Skip to content

Commit 81b8ac3

Browse files
committed
Support for multi-dimensional arrays
1 parent 51ade21 commit 81b8ac3

File tree

3 files changed

+158
-46
lines changed

3 files changed

+158
-46
lines changed

src/array_api_extra/_delegation.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ def pad(
331331
def partition(
332332
a: Array,
333333
kth: int,
334+
/,
335+
axis: int | None = -1,
334336
*,
335337
xp: ModuleType | None = None,
336338
) -> Array:
@@ -343,6 +345,9 @@ def partition(
343345
Input array.
344346
kth : int
345347
Element index to partition by.
348+
axis : int, optional
349+
Axis along which to partition. The default is -1 (the last axis).
350+
If None, the flattened array is used.
346351
xp : array_namespace, optional
347352
The standard-compatible namespace for `x`. Default: infer.
348353
@@ -354,36 +359,61 @@ def partition(
354359
# Validate inputs.
355360
if xp is None:
356361
xp = array_namespace(a)
357-
if a.ndim != 1:
358-
msg = "only 1-dimensional arrays are currently supported"
359-
raise NotImplementedError(msg)
362+
if a.ndim < 1:
363+
msg = "`a` must be at least 1-dimensional"
364+
raise TypeError(msg)
365+
if axis is None:
366+
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
367+
size = a.shape[axis]
368+
if size is None:
369+
msg = "Array dimensions must be known"
370+
raise ValueError(msg)
371+
if not (0 <= kth < size):
372+
msg = f"kth(={kth}) out of bounds [0 {size})"
373+
raise ValueError(msg)
360374

361375
# Delegate where possible.
362-
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
363-
return xp.partition(a, kth)
364-
if is_jax_namespace(xp):
365-
from jax import numpy
366-
367-
return numpy.partition(a, kth)
376+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
377+
return xp.partition(a, kth, axis=axis)
368378

369379
# Use top-k when possible:
370380
if is_torch_namespace(xp):
371-
from torch import topk
381+
if not (axis == -1 or axis == a.ndim - 1):
382+
a = xp.transpose(a, axis, -1)
372383

373-
a_left, indices_left = topk(a, kth, largest=False, sorted=False)
384+
# Get smallest `kth` elements along axis
385+
kth += 1 # HACK: we use a non-specified behavior of torch.topk:
386+
# in `a_left`, the element in the last position is the max
387+
a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
388+
389+
# Build a mask to remove the selected elements
374390
mask_right = xp.ones(a.shape, dtype=bool)
375-
mask_right[indices_left] = False
376-
return xp.concat((a_left, a[mask_right]))
391+
mask_right.scatter_(dim=-1, index=indices, value=False)
392+
393+
# Remaining elements along axis
394+
a_right = a[mask_right] # 1-d array
395+
396+
# Reshape. This is valid only because we work on the last axis
397+
a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1))
398+
399+
# Concatenate the two parts along axis
400+
partitioned_array = xp.cat((a_left, a_right), dim=-1)
401+
if not (axis == -1 or axis == a.ndim - 1):
402+
partitioned_array = xp.transpose(partitioned_array, axis, -1)
403+
return partitioned_array
404+
377405
# Note: dask topk/argtopk sort the return values, so it's
378406
# not much more efficient than sorting everything when
379407
# kth is not small compared to x.size
380408

381-
return _funcs.partition(a, kth, xp=xp)
409+
return _funcs.partition(a, kth, axis=axis, xp=xp)
382410

383411

384412
def argpartition(
385413
a: Array,
386414
kth: int,
415+
/,
416+
axis: int | None = -1,
387417
*,
388418
xp: ModuleType | None = None,
389419
) -> Array:
@@ -392,10 +422,13 @@ def argpartition(
392422
393423
Parameters
394424
----------
395-
a : 1-dimensional array
425+
a : Array
396426
Input array.
397427
kth : int
398428
Element index to partition by.
429+
axis : int, optional
430+
Axis along which to partition. The default is -1 (the last axis).
431+
If None, the flattened array is used.
399432
xp : array_namespace, optional
400433
The standard-compatible namespace for `x`. Default: infer.
401434
@@ -407,29 +440,46 @@ def argpartition(
407440
# Validate inputs.
408441
if xp is None:
409442
xp = array_namespace(a)
410-
if a.ndim != 1:
411-
msg = "only 1-dimensional arrays are currently supported"
412-
raise NotImplementedError(msg)
443+
if a.ndim < 1:
444+
msg = "`a` must be at least 1-dimensional"
445+
raise TypeError(msg)
446+
if axis is None:
447+
return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
448+
size = a.shape[axis]
449+
if size is None:
450+
msg = "Array dimensions must be known"
451+
raise ValueError(msg)
452+
if not (0 <= kth < size):
453+
msg = f"kth(={kth}) out of bounds [0 {size})"
454+
raise ValueError(msg)
413455

414456
# Delegate where possible.
415-
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
416-
return xp.argpartition(a, kth)
417-
if is_jax_namespace(xp):
418-
from jax import numpy
419-
420-
return numpy.argpartition(a, kth)
457+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
458+
return xp.argpartition(a, kth, axis=axis)
421459

422460
# Use top-k when possible:
423461
if is_torch_namespace(xp):
424-
from torch import topk
462+
# see `partition` above for commented details of those steps:
463+
if not (axis == -1 or axis == a.ndim - 1):
464+
a = xp.transpose(a, axis, -1)
465+
466+
kth += 1 # HACK
467+
_, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
468+
469+
mask_right = xp.ones(a.shape, dtype=bool)
470+
mask_right.scatter_(dim=-1, index=indices_left, value=False)
471+
472+
indices_right = xp.nonzero(mask_right)[-1]
473+
indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1))
474+
475+
# Concatenate the two parts along axis
476+
index_array = xp.cat((indices_left, indices_right), dim=-1)
477+
if not (axis == -1 or axis == a.ndim - 1):
478+
index_array = xp.transpose(index_array, axis, -1)
479+
return index_array
425480

426-
_, indices = topk(a, kth, largest=False, sorted=False)
427-
mask = xp.ones(a.shape, dtype=bool)
428-
mask[indices] = False
429-
indices_above = xp.arange(a.shape[0])[mask]
430-
return xp.concat((indices, indices_above))
431481
# Note: dask topk/argtopk sort the return values, so it's
432482
# not much more efficient than sorting everything when
433483
# kth is not small compared to x.size
434484

435-
return _funcs.argpartition(a, kth, xp=xp)
485+
return _funcs.argpartition(a, kth, axis=axis, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,18 +1034,22 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
10341034
def partition( # numpydoc ignore=PR01,RT01
10351035
x: Array,
10361036
kth: int, # noqa: ARG001
1037+
/,
1038+
axis: int = -1,
10371039
*,
10381040
xp: ModuleType,
10391041
) -> Array:
10401042
"""See docstring in `array_api_extra._delegation.py`."""
1041-
return xp.sort(x, stable=False)
1043+
return xp.sort(x, axis=axis, stable=False)
10421044

10431045

10441046
def argpartition( # numpydoc ignore=PR01,RT01
10451047
x: Array,
10461048
kth: int, # noqa: ARG001
1049+
/,
1050+
axis: int = -1,
10471051
*,
10481052
xp: ModuleType,
10491053
) -> Array:
10501054
"""See docstring in `array_api_extra._delegation.py`."""
1051-
return xp.argsort(x, stable=False)
1055+
return xp.argsort(x, axis=axis, stable=False)

tests/test_funcs.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
from hypothesis import given
1111
from hypothesis import strategies as st
12+
from typing_extensions import override
1213

1314
from array_api_extra import (
1415
apply_where,
@@ -32,7 +33,12 @@
3233
)
3334
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3435
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
35-
from array_api_extra._lib._utils._compat import device as get_device
36+
from array_api_extra._lib._utils._compat import (
37+
device as get_device,
38+
)
39+
from array_api_extra._lib._utils._compat import (
40+
is_pydata_sparse_namespace,
41+
)
3642
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
3743
from array_api_extra._lib._utils._typing import Array, Device
3844
from array_api_extra.testing import lazy_xp_function
@@ -1303,15 +1309,67 @@ def test_xp(self, xp: ModuleType):
13031309

13041310

13051311
class TestPartition:
1306-
def test_basic(self, xp: ModuleType):
1307-
# Using 0-dimensional array
1308-
rng = np.random.default_rng(2847)
1309-
1310-
for _ in range(100):
1311-
n = rng.integers(1, 1000)
1312-
x = xp.asarray(rng.random(size=n))
1313-
k = int(rng.integers(1, n - 1))
1314-
y = partition(x, k)
1315-
assert xp.max(y[:k]) <= xp.min(y[k:])
1316-
y = x[argpartition(x, k)]
1317-
assert xp.max(y[:k]) <= xp.min(y[k:])
1312+
@classmethod
1313+
def _assert_valid_partition(cls, x: Array, k: int, xp: ModuleType, axis: int = -1):
1314+
if x.ndim != 1 and axis == 0:
1315+
assert isinstance(x.shape[1], int)
1316+
for i in range(x.shape[1]):
1317+
cls._assert_valid_partition(x[:, i, ...], k, xp, axis=0)
1318+
elif x.ndim != 1:
1319+
axis = axis - 1 if axis != -1 else -1
1320+
assert isinstance(x.shape[0], int)
1321+
for i in range(x.shape[0]):
1322+
cls._assert_valid_partition(x[i, ...], k, xp, axis=axis)
1323+
else:
1324+
if k > 0:
1325+
assert xp.max(x[:k]) <= x[k]
1326+
assert x[k] <= xp.min(x[k:])
1327+
1328+
@classmethod
1329+
def _partition(
1330+
cls,
1331+
x: Array,
1332+
k: int,
1333+
xp: ModuleType, # noqa: ARG003
1334+
axis: int | None = -1,
1335+
):
1336+
return partition(x, k, axis=axis)
1337+
1338+
def test_1d(self, xp: ModuleType):
1339+
rng = np.random.default_rng()
1340+
for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
1341+
k = int(rng.integers(n))
1342+
x = xp.asarray(rng.integers(n, size=n))
1343+
self._assert_valid_partition(self._partition(x, k, xp), k, xp)
1344+
x = xp.asarray(rng.random(n))
1345+
self._assert_valid_partition(self._partition(x, k, xp), k, xp)
1346+
1347+
@pytest.mark.parametrize("ndim", [2, 3, 4, 5])
1348+
def test_nd(self, xp: ModuleType, ndim: int):
1349+
rng = np.random.default_rng()
1350+
1351+
for n in [2, 3, 5, 10, 20, 100]:
1352+
base_shape = [int(v) for v in rng.integers(1, 4, size=ndim)]
1353+
k = int(rng.integers(n))
1354+
1355+
for i in range(ndim):
1356+
shape = base_shape[:]
1357+
shape[i] = n
1358+
x = xp.asarray(rng.integers(n, size=tuple(shape)))
1359+
y = self._partition(x, k, xp, axis=i)
1360+
self._assert_valid_partition(y, k, xp, axis=i)
1361+
1362+
1363+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
1364+
class TestArgpartition(TestPartition):
1365+
@classmethod
1366+
@override
1367+
def _partition(cls, x: Array, k: int, xp: ModuleType, axis: int | None = -1):
1368+
if is_pydata_sparse_namespace(xp):
1369+
pytest.xfail(reason="Sparse backend has no argsort")
1370+
indices = argpartition(x, k, axis=axis)
1371+
if x.ndim == 1:
1372+
return x[indices]
1373+
if not hasattr(xp, "take_along_axis"):
1374+
pytest.skip("TODO: find an alternative to take_along_axis")
1375+
return xp.take_along_axis(x, indices, axis=axis)

0 commit comments

Comments
 (0)