Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, nan_to_num, one_hot, pad
from ._delegation import (
argpartition,
isclose,
nan_to_num,
one_hot,
pad,
partition,
)
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand All @@ -23,6 +30,7 @@
__all__ = [
"__version__",
"apply_where",
"argpartition",
"at",
"atleast_nd",
"broadcast_shapes",
Expand All @@ -37,6 +45,7 @@
"nunique",
"one_hot",
"pad",
"partition",
"setdiff1d",
"sinc",
]
107 changes: 107 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,110 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def partition(
a: Array,
kth: int,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Return a partitioned copy of an array.

Parameters
----------
a : 1-dimensional array
Input array.
kth : int
Element index to partition by.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
partitioned_array
Array of the same type and shape as a.
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if a.ndim != 1:
msg = "only 1-dimensional arrays are currently supported"
raise NotImplementedError(msg)

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
return xp.partition(a, kth)
if is_jax_namespace(xp):
from jax import numpy

return numpy.partition(a, kth)

# Use top-k when possible:
if is_torch_namespace(xp):
from torch import topk

a_left, indices_left = topk(a, kth, largest=False, sorted=False)
mask_right = xp.ones(a.shape, dtype=bool)
mask_right[indices_left] = False
return xp.concat((a_left, a[mask_right]))
# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.partition(a, kth, xp=xp)


def argpartition(
a: Array,
kth: int,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Perform an indirect partition along the given axis.

Parameters
----------
a : 1-dimensional array
Input array.
kth : int
Element index to partition by.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
index_array
Array of indices that partition `a` along the specified axis.
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if a.ndim != 1:
msg = "only 1-dimensional arrays are currently supported"
raise NotImplementedError(msg)

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
return xp.argpartition(a, kth)
if is_jax_namespace(xp):
from jax import numpy

return numpy.argpartition(a, kth)

# Use top-k when possible:
if is_torch_namespace(xp):
from torch import topk

_, indices = topk(a, kth, largest=False, sorted=False)
mask = xp.ones(a.shape, dtype=bool)
mask[indices] = False
indices_above = xp.arange(a.shape[0])[mask]
return xp.concat((indices, indices_above))
# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.argpartition(a, kth, xp=xp)
20 changes: 20 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,3 +1029,23 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def partition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.sort(x, stable=False)


def argpartition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.argsort(x, stable=False)
17 changes: 17 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from array_api_extra import (
apply_where,
argpartition,
at,
atleast_nd,
broadcast_shapes,
Expand All @@ -25,6 +26,7 @@
nunique,
one_hot,
pad,
partition,
setdiff1d,
sinc,
)
Expand Down Expand Up @@ -1298,3 +1300,18 @@ def test_device(self, xp: ModuleType, device: Device):

def test_xp(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPartition:
def test_basic(self, xp: ModuleType):
# Using 0-dimensional array
rng = np.random.default_rng(2847)

for _ in range(100):
n = rng.integers(1, 1000)
x = xp.asarray(rng.random(size=n))
k = int(rng.integers(1, n - 1))
y = partition(x, k)
assert xp.max(y[:k]) <= xp.min(y[k:])
y = x[argpartition(x, k)]
assert xp.max(y[:k]) <= xp.min(y[k:])
Loading