Merged
Changes from 3 commits
11 changes: 10 additions & 1 deletion src/array_api_extra/__init__.py
@@ -1,6 +1,13 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, nan_to_num, one_hot, pad
from ._delegation import (
    argpartition,
    isclose,
    nan_to_num,
    one_hot,
    pad,
    partition,
)
from ._lib._at import at
from ._lib._funcs import (
    apply_where,
@@ -23,6 +30,7 @@
__all__ = [
    "__version__",
    "apply_where",
    "argpartition",
    "at",
    "atleast_nd",
    "broadcast_shapes",
@@ -37,6 +45,7 @@
    "nunique",
    "one_hot",
    "pad",
    "partition",
    "setdiff1d",
    "sinc",
]
157 changes: 157 additions & 0 deletions src/array_api_extra/_delegation.py
@@ -326,3 +326,160 @@ def pad(
        return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

    return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def partition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Return a partitioned copy of an array.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is -1 (the last axis).
        If None, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    partitioned_array
        Array of the same type and shape as `a`.
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        return partition(xp.reshape(a, -1), kth, axis=0, xp=xp)
    size = a.shape[axis]
    if size is None:
        msg = "Array dimensions must be known"
        raise ValueError(msg)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0, {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.partition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        # Get the smallest `kth` elements along the axis.
        kth += 1  # HACK: we use a non-specified behavior of torch.topk:
        # in `a_left`, the element in the last position is the max
        a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)

Member: hmmm, I would rather not rely on undocumented behaviour. Is there an alternative?

Contributor Author: Fair ^^

Three options:

  • add an `assert a_left.max() == a_left[k]`
  • We can just re-run the same logic with `kth=1` and `largest=True`. The impact on performance is probably 10% to 100% slower depending on the input, but it doesn't add a lot of logic.
  • We can do `if a_left.max() != a_left[k]: swap_max_with_last_element(a_left, axis=-1)`, which requires implementing `swap_max_with_last_element` (and the equivalent for argsort).

I vote for 1 because I'm lazy, but I like perf :p

Contributor Author: Edit: wait, I need to rethink something about the numpy.partition specs...

Contributor Author (@cakedev0, Oct 2, 2025): So! I entirely rewrote this section; it now relies on torch.kthvalue and is closely aligned with numpy's behavior.

On a side note: the description of the behavior of the partition function in numpy is fairly blurry when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).

I will maybe open an issue on numpy to ask for some clarification.

Member: Good idea!

Contributor:

> In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't.

It might be worth contributing this consideration to the array API spec discussion:

        # Build a mask to remove the selected elements
        mask_right = xp.ones(a.shape, dtype=bool)
        mask_right.scatter_(dim=-1, index=indices, value=False)

        # Remaining elements along axis
        a_right = a[mask_right]  # 1-d array

        # Reshape. This is valid only because we work on the last axis
        a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1))

        # Concatenate the two parts along axis
        partitioned_array = xp.cat((a_left, a_right), dim=-1)
        if not (axis == -1 or axis == a.ndim - 1):
            partitioned_array = xp.transpose(partitioned_array, axis, -1)
        return partitioned_array

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.partition(a, kth, axis=axis, xp=xp)
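
An editor's sketch for illustration only: the review thread above says the merged version relies on torch.kthvalue with a three-way (<, ==, >) grouping. Under that assumption, a minimal 1-D version could look like this (the helper name is hypothetical and this is not the code that was merged):

import torch

def kthvalue_partition_1d(a: torch.Tensor, kth: int) -> torch.Tensor:
    # torch.kthvalue is 1-based, so kth + 1 picks the value a full sort
    # would place at index `kth`.
    pivot = torch.kthvalue(a, kth + 1).values
    # Three-way grouping around the pivot: strictly smaller values first,
    # then the duplicates of the pivot, then strictly larger values.
    return torch.cat((a[a < pivot], a[a == pivot], a[a > pivot]))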


def argpartition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Perform an indirect partition along the given axis.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is -1 (the last axis).
        If None, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    index_array
        Array of indices that partition `a` along the specified axis.
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        return argpartition(xp.reshape(a, -1), kth, axis=0, xp=xp)
    size = a.shape[axis]
    if size is None:
        msg = "Array dimensions must be known"
        raise ValueError(msg)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0, {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.argpartition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        # See `partition` above for commented details of these steps:
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        kth += 1  # HACK
        _, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False)

        mask_right = xp.ones(a.shape, dtype=bool)
        mask_right.scatter_(dim=-1, index=indices_left, value=False)

        indices_right = xp.nonzero(mask_right)[-1]
        indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1))

        # Concatenate the two parts along axis
        index_array = xp.cat((indices_left, indices_right), dim=-1)
        if not (axis == -1 or axis == a.ndim - 1):
            index_array = xp.transpose(index_array, axis, -1)
        return index_array

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.argpartition(a, kth, axis=axis, xp=xp)
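
For reference, a minimal usage sketch (editor's illustration, assuming a NumPy-backed array) of how the indices returned by argpartition map back to partitioned values, mirroring what the tests below do with take_along_axis:

import numpy as np
import array_api_extra as xpx

a = np.asarray([[7, 1, 5, 3], [2, 8, 4, 6]])
idx = xpx.argpartition(a, 2, axis=-1)
vals = np.take_along_axis(a, idx, axis=-1)
# In each row, vals[..., 2] is the value a full sort would place at index 2,
# with smaller-or-equal values before it and greater-or-equal values after.
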
24 changes: 24 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
@@ -1029,3 +1029,27 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
        xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
    )
    return xp.sin(y) / y


def partition(  # numpydoc ignore=PR01,RT01
    x: Array,
    kth: int,  # noqa: ARG001
    /,
    axis: int = -1,
    *,
    xp: ModuleType,
) -> Array:
    """See docstring in `array_api_extra._delegation.py`."""
    return xp.sort(x, axis=axis, stable=False)


def argpartition(  # numpydoc ignore=PR01,RT01
    x: Array,
    kth: int,  # noqa: ARG001
    /,
    axis: int = -1,
    *,
    xp: ModuleType,
) -> Array:
    """See docstring in `array_api_extra._delegation.py`."""
    return xp.argsort(x, axis=axis, stable=False)
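
The fallbacks above simply sort, which is a strict superset of the partition guarantee; a quick sanity check of that reasoning (editor's sketch, using NumPy in place of a generic backend):

import numpy as np

x = np.asarray([5, 1, 4, 1, 3])
k = 2
p = np.sort(x)  # what the fallback effectively returns
# Partition property at index k: everything left is <= p[k] <= everything right.
assert np.max(p[:k]) <= p[k] <= np.min(p[k:])
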
83 changes: 82 additions & 1 deletion tests/test_funcs.py
@@ -9,9 +9,11 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from typing_extensions import override

from array_api_extra import (
    apply_where,
    argpartition,
    at,
    atleast_nd,
    broadcast_shapes,
@@ -25,12 +27,18 @@
    nunique,
    one_hot,
    pad,
    partition,
    setdiff1d,
    sinc,
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._compat import (
    device as get_device,
)
from array_api_extra._lib._utils._compat import (
    is_pydata_sparse_namespace,
)
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
@@ -1298,3 +1306,76 @@ def test_device(self, xp: ModuleType, device: Device):

    def test_xp(self, xp: ModuleType):
        xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPartition:
    @classmethod
    def _assert_valid_partition(cls, x: Array, k: int, xp: ModuleType, axis: int = -1):
        if x.ndim != 1 and axis == 0:
            assert isinstance(x.shape[1], int)
            for i in range(x.shape[1]):
                cls._assert_valid_partition(x[:, i, ...], k, xp, axis=0)
        elif x.ndim != 1:
            axis = axis - 1 if axis != -1 else -1
            assert isinstance(x.shape[0], int)
            for i in range(x.shape[0]):
                cls._assert_valid_partition(x[i, ...], k, xp, axis=axis)
        else:
            if k > 0:
                assert xp.max(x[:k]) <= x[k]
            assert x[k] <= xp.min(x[k:])

    @classmethod
    def _partition(
        cls,
        x: Array,
        k: int,
        xp: ModuleType,  # noqa: ARG003
        axis: int | None = -1,
    ):
        return partition(x, k, axis=axis)

    def test_1d(self, xp: ModuleType):
        rng = np.random.default_rng()
        for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
            k = int(rng.integers(n))
            x = xp.asarray(rng.integers(n, size=n))
            self._assert_valid_partition(self._partition(x, k, xp), k, xp)
            x = xp.asarray(rng.random(n))
            self._assert_valid_partition(self._partition(x, k, xp), k, xp)

    @pytest.mark.parametrize("ndim", [2, 3, 4, 5])
    def test_nd(self, xp: ModuleType, ndim: int):
        rng = np.random.default_rng()

        for n in [2, 3, 5, 10, 20, 100]:
            base_shape = [int(v) for v in rng.integers(1, 4, size=ndim)]
            k = int(rng.integers(n))

            for i in range(ndim):
                shape = base_shape[:]
                shape[i] = n
                x = xp.asarray(rng.integers(n, size=tuple(shape)))
                y = self._partition(x, k, xp, axis=i)
                self._assert_valid_partition(y, k, xp, axis=i)

    def test_input_validation(self, xp: ModuleType):
        with pytest.raises(TypeError):
            _ = self._partition(xp.asarray(1), 1, xp)
        with pytest.raises(ValueError, match="out of bounds"):
            _ = self._partition(xp.asarray([1, 2]), 3, xp)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
class TestArgpartition(TestPartition):
    @classmethod
    @override
    def _partition(cls, x: Array, k: int, xp: ModuleType, axis: int | None = -1):
        if is_pydata_sparse_namespace(xp):
            pytest.xfail(reason="Sparse backend has no argsort")
        indices = argpartition(x, k, axis=axis)
        if x.ndim == 1:
            return x[indices]
        if not hasattr(xp, "take_along_axis"):
            pytest.skip("TODO: find an alternative to take_along_axis")
        return xp.take_along_axis(x, indices, axis=axis)
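
Hypothetical end-to-end example of the new public API on a NumPy array (editor's sketch, not part of this PR's test suite):

import numpy as np
import array_api_extra as xpx

x = np.asarray([9, 4, 7, 1, 8, 2])
y = xpx.partition(x, 2)
# y[2] is the third-smallest value (4); the left side holds values <= 4 and
# the right side values >= 4, but neither side is itself sorted.
assert y[2] == 4
assert y[:2].max() <= 4 <= y[3:].min()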