diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index d901a7f9..87ef6986 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -1,6 +1,7 @@
 """Extra array functions built on top of the array API standard."""
 
 from ._delegation import (
+    argpartition,
     atleast_nd,
     cov,
     expand_dims,
@@ -8,6 +9,7 @@
     nan_to_num,
     one_hot,
     pad,
+    partition,
     sinc,
 )
 from ._lib._at import at
@@ -28,6 +30,7 @@
 __all__ = [
     "__version__",
     "apply_where",
+    "argpartition",
     "at",
     "atleast_nd",
     "broadcast_shapes",
@@ -42,6 +45,7 @@
     "nunique",
     "one_hot",
     "pad",
+    "partition",
     "setdiff1d",
     "sinc",
 ]
diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py
index 75b6caeb..7f467366 100644
--- a/src/array_api_extra/_delegation.py
+++ b/src/array_api_extra/_delegation.py
@@ -15,7 +15,7 @@
     is_torch_namespace,
 )
 from ._lib._utils._compat import device as get_device
-from ._lib._utils._helpers import asarrays
+from ._lib._utils._helpers import asarrays, eager_shape
 from ._lib._utils._typing import Array, DType
 
 __all__ = [
@@ -645,3 +645,194 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
         return xp.sinc(x)
 
     return _funcs.sinc(x, xp=xp)
+
+
+def partition(
+    a: Array,
+    kth: int,
+    /,
+    axis: int | None = -1,
+    *,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    Return a partitioned copy of an array.
+
+    Creates a copy of the array and partially sorts it in such a way that the value
+    of the element in k-th position is in the position it would be in a sorted array.
+    In the output array, all elements smaller than the k-th element are located to
+    the left of this element and all equal or greater are located to its right.
+    The ordering of the elements in the two partitions on either side of
+    the k-th element in the output array is undefined.
+
+    Parameters
+    ----------
+    a : Array
+        Input array.
+    kth : int
+        Element index to partition by.
+    axis : int, optional
+        Axis along which to partition. The default is ``-1`` (the last axis).
+        If ``None``, the flattened array is used.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `a`. Default: infer.
+
+    Returns
+    -------
+    partitioned_array
+        Array of the same type as `a`, and same shape unless `axis` is ``None``.
+
+    Notes
+    -----
+    If `xp` implements ``partition`` or an equivalent function
+    (e.g. ``kthvalue`` for torch), complexity will likely be O(n).
+    If not, this function simply calls ``xp.sort`` and complexity is O(n log n).
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(a)
+    if a.ndim < 1:
+        msg = "`a` must be at least 1-dimensional"
+        raise TypeError(msg)
+    if axis is None:
+        return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
+    (size,) = eager_shape(a, axis)
+    if not (0 <= kth < size):
+        msg = f"kth(={kth}) out of bounds [0 {size})"
+        raise ValueError(msg)
+
+    # Delegate where possible.
+    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
+        return xp.partition(a, kth, axis=axis)
+
+    # Use kthvalue-based selection for torch:
+    if is_torch_namespace(xp):
+        if not (axis == -1 or axis == a.ndim - 1):
+            a = xp.transpose(a, axis, -1)
+
+        out = xp.empty_like(a)
+        ranks = xp.arange(a.shape[-1], device=get_device(a)).expand_as(a)
+
+        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
+        del indices  # indices won't be used => del ASAP to reduce peak memory usage
+
+        # fill the left-side of the partition
+        mask_src = a < split_value
+        n_left = mask_src.sum(dim=-1, keepdim=True)
+        mask_dest = ranks < n_left
+        out[mask_dest] = a[mask_src]
+
+        # fill the middle of the partition
+        mask_src = a == split_value
+        n_left += mask_src.sum(dim=-1, keepdim=True)
+        mask_dest ^= ranks < n_left
+        out[mask_dest] = a[mask_src]
+
+        # fill the right-side of the partition
+        mask_src = a > split_value
+        mask_dest = ranks >= n_left
+        out[mask_dest] = a[mask_src]
+
+        if not (axis == -1 or axis == a.ndim - 1):
+            out = xp.transpose(out, axis, -1)
+        return out
+
+    # Note: dask topk/argtopk sort the return values, so it's
+    # not much more efficient than sorting everything when
+    # kth is not small compared to a.size
+
+    return _funcs.partition(a, kth, axis=axis, xp=xp)
+
+
+def argpartition(
+    a: Array,
+    kth: int,
+    /,
+    axis: int | None = -1,
+    *,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    Perform an indirect partition along the given axis.
+
+    It returns an array of indices of the same shape as `a` that
+    index data along the given axis in partitioned order.
+
+    Parameters
+    ----------
+    a : Array
+        Input array.
+    kth : int
+        Element index to partition by.
+    axis : int, optional
+        Axis along which to partition. The default is ``-1`` (the last axis).
+        If ``None``, the flattened array is used.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `a`. Default: infer.
+
+    Returns
+    -------
+    index_array
+        Array of indices that partition `a` along the specified axis.
+
+    Notes
+    -----
+    If `xp` implements ``argpartition`` or an equivalent function
+    (e.g. ``kthvalue`` for torch), complexity will likely be O(n).
+    If not, this function simply calls ``xp.argsort`` and complexity is O(n log n).
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(a)
+    if is_pydata_sparse_namespace(xp):
+        msg = "Not implemented for sparse backend: no argsort"
+        raise NotImplementedError(msg)
+    if a.ndim < 1:
+        msg = "`a` must be at least 1-dimensional"
+        raise TypeError(msg)
+    if axis is None:
+        return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
+    (size,) = eager_shape(a, axis)
+    if not (0 <= kth < size):
+        msg = f"kth(={kth}) out of bounds [0 {size})"
+        raise ValueError(msg)
+
+    # Delegate where possible.
+    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
+        return xp.argpartition(a, kth, axis=axis)
+
+    # Use kthvalue-based selection for torch:
+    if is_torch_namespace(xp):
+        # see `partition` above for commented details of those steps:
+        if not (axis == -1 or axis == a.ndim - 1):
+            a = xp.transpose(a, axis, -1)
+
+        ranks = xp.arange(a.shape[-1], device=get_device(a)).expand_as(a)
+        out = xp.empty_like(ranks)
+
+        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
+        del indices  # indices won't be used => del ASAP to reduce peak memory usage
+
+        mask_src = a < split_value
+        n_left = mask_src.sum(dim=-1, keepdim=True)
+        mask_dest = ranks < n_left
+        out[mask_dest] = ranks[mask_src]
+
+        mask_src = a == split_value
+        n_left += mask_src.sum(dim=-1, keepdim=True)
+        mask_dest ^= ranks < n_left
+        out[mask_dest] = ranks[mask_src]
+
+        mask_src = a > split_value
+        mask_dest = ranks >= n_left
+        out[mask_dest] = ranks[mask_src]
+
+        if not (axis == -1 or axis == a.ndim - 1):
+            out = xp.transpose(out, axis, -1)
+        return out
+
+    # Note: dask topk/argtopk sort the return values, so it's
+    # not much more efficient than sorting everything when
+    # kth is not small compared to a.size
+
+    return _funcs.argpartition(a, kth, axis=axis, xp=xp)
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index 65c78b04..fe52305f 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -777,3 +777,27 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
         xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
     )
     return xp.sin(y) / y
+
+
+def partition(  # numpydoc ignore=PR01,RT01
+    x: Array,
+    kth: int,  # noqa: ARG001
+    /,
+    axis: int = -1,
+    *,
+    xp: ModuleType,
+) -> Array:
+    """See docstring in `array_api_extra._delegation.py`."""
+    return xp.sort(x, axis=axis, stable=False)
+
+
+def argpartition(  # numpydoc ignore=PR01,RT01
+    x: Array,
+    kth: int,  # noqa: ARG001
+    /,
+    axis: int = -1,
+    *,
+    xp: ModuleType,
+) -> Array:
+    """See docstring in `array_api_extra._delegation.py`."""
+    return xp.argsort(x, axis=axis, stable=False)
diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py
index 6dd94a38..fbe986a1 100644
--- a/src/array_api_extra/_lib/_utils/_helpers.py
+++ b/src/array_api_extra/_lib/_utils/_helpers.py
@@ -250,7 +250,7 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]:
             yield *i, j
 
 
-def eager_shape(x: Array, /) -> tuple[int, ...]:
+def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]:
     """
     Return shape of an array. Raise if shape is not fully defined.
 
@@ -258,6 +258,8 @@ def eager_shape(x: Array, /) -> tuple[int, ...]:
     ----------
     x : Array
         Input array.
+    axis : int, optional
+        If provided, only returns the tuple (shape[axis],).
 
     Returns
     -------
@@ -265,7 +267,14 @@ def eager_shape(x: Array, /) -> tuple[int, ...]:
         Shape of the array.
     """
     shape = x.shape
-    # Dask arrays uses non-standard NaN instead of None
+    if axis is not None:
+        s = shape[axis]
+        # Dask arrays use non-standard NaN instead of None
+        if s is None or math.isnan(s):
+            msg = f"Unsupported lazy shape for axis {axis}"
+            raise TypeError(msg)
+        return (s,)
+
     if any(s is None or math.isnan(s) for s in shape):
         msg = "Unsupported lazy shape"
         raise TypeError(msg)
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index ef627eb1..ac7f4d8c 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -9,9 +9,11 @@
 import pytest
 from hypothesis import given
 from hypothesis import strategies as st
+from typing_extensions import override
 
 from array_api_extra import (
     apply_where,
+    argpartition,
     at,
     atleast_nd,
     broadcast_shapes,
@@ -25,12 +27,15 @@
     nunique,
     one_hot,
     pad,
+    partition,
     setdiff1d,
     sinc,
 )
 from array_api_extra._lib._backends import NUMPY_VERSION, Backend
 from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
-from array_api_extra._lib._utils._compat import device as get_device
+from array_api_extra._lib._utils._compat import (
+    device as get_device,
+)
 from array_api_extra._lib._utils._helpers import eager_shape, ndindex
 from array_api_extra._lib._utils._typing import Array, Device
 from array_api_extra.testing import lazy_xp_function
@@ -1322,3 +1327,139 @@ def test_device(self, xp: ModuleType, device: Device):
 
     def test_xp(self, xp: ModuleType):
         xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
+
+
+class TestPartition:
+    @classmethod
+    def _assert_valid_partition(
+        cls,
+        x_np: np.ndarray | None,
+        k: int,
+        y: Array,
+        xp: ModuleType,
+        axis: int | None = -1,
+    ):
+        """
+        x_np : input array
+        k : int
+        y : output array returned by the partition function to test
+        """
+        if x_np is not None:
+            assert y.shape == np.partition(x_np, k, axis=axis).shape
+        if y.ndim != 1 and axis == 0:
+            assert isinstance(y.shape[1], int)
+            for i in range(y.shape[1]):
+                cls._assert_valid_partition(None, k, y[:, i, ...], xp, axis=0)
+        elif y.ndim != 1:
+            assert axis is not None
+            axis = axis - 1 if axis != -1 else -1
+            assert isinstance(y.shape[0], int)
+            for i in range(y.shape[0]):
+                cls._assert_valid_partition(None, k, y[i, ...], xp, axis=axis)
+        else:
+            if k > 0:
+                assert xp.max(y[:k]) <= y[k]
+            assert y[k] <= xp.min(y[k:])
+
+    @classmethod
+    def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
+        return partition(xp.asarray(x), k, axis=axis)
+
+    def _test_1d(self, xp: ModuleType):
+        rng = np.random.default_rng()
+        for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]:
+            k = int(rng.integers(n))
+            x1 = rng.integers(n, size=n)
+            y = self._partition(x1, k, xp)
+            self._assert_valid_partition(x1, k, y, xp)
+            x2 = rng.random(n)
+            y = self._partition(x2, k, xp)
+            self._assert_valid_partition(x2, k, y, xp)
+
+    def _test_nd(self, xp: ModuleType, ndim: int):
+        rng = np.random.default_rng()
+
+        for n in [2, 3, 5, 10, 20, 100]:
+            base_shape = [int(v) for v in rng.integers(1, 4, size=ndim)]
+            k = int(rng.integers(n))
+
+            for i in range(ndim):
+                shape = base_shape[:]
+                shape[i] = n
+                x = rng.integers(n, size=tuple(shape))
+                y = self._partition(x, k, xp, axis=i)
+                self._assert_valid_partition(x, k, y, xp, axis=i)
+
+            z = rng.random(tuple(base_shape))
+            k = int(rng.integers(z.size))
+            y = self._partition(z, k, xp, axis=None)
+            self._assert_valid_partition(z, k, y, xp, axis=None)
+
+    def _test_input_validation(self, xp: ModuleType):
+        with pytest.raises(TypeError):
+            _ = self._partition(np.asarray(1), 1, xp)
+        with pytest.raises(ValueError, match="out of bounds"):
+            _ = self._partition(np.asarray([1, 2]), 3, xp)
+
+    def test_1d(self, xp: ModuleType):
+        self._test_1d(xp)
+
+    @pytest.mark.parametrize("ndim", [2, 3, 4])
+    def test_nd(self, xp: ModuleType, ndim: int):
+        self._test_nd(xp, ndim)
+
+    def test_input_validation(self, xp: ModuleType):
+        self._test_input_validation(xp)
+
+
+@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
+class TestArgpartition(TestPartition):
+    @classmethod
+    @override
+    def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1):
+        arr = xp.asarray(x)
+        indices = argpartition(arr, k, axis=axis)
+        if axis is None:
+            arr = xp.reshape(arr, shape=(-1,))
+            return arr[indices]
+        if arr.ndim == 1:
+            return arr[indices]
+        return cls._take_along_axis(arr, indices, axis=axis, xp=xp)
+
+    @classmethod
+    def _take_along_axis(cls, arr: Array, indices: Array, axis: int, xp: ModuleType):
+        if hasattr(xp, "take_along_axis"):
+            return xp.take_along_axis(arr, indices, axis=axis)
+        if arr.ndim == 1:
+            return arr[indices]
+        if axis == 0:
+            assert isinstance(arr.shape[1], int)
+            arrs = []
+            for i in range(arr.shape[1]):
+                arrs.append(
+                    cls._take_along_axis(
+                        arr[:, i, ...], indices[:, i, ...], axis=0, xp=xp
+                    )
+                )
+            return xp.stack(arrs, axis=1)
+        axis = axis - 1 if axis != -1 else -1
+        assert isinstance(arr.shape[0], int)
+        arrs = []
+        for i in range(arr.shape[0]):
+            arrs.append(
+                cls._take_along_axis(arr[i, ...], indices[i, ...], axis=axis, xp=xp)
+            )
+        return xp.stack(arrs, axis=0)
+
+    @override
+    def test_1d(self, xp: ModuleType):
+        self._test_1d(xp)
+
+    @pytest.mark.parametrize("ndim", [2, 3, 4])
+    @override
+    def test_nd(self, xp: ModuleType, ndim: int):
+        self._test_nd(xp, ndim)
+
+    @override
+    def test_input_validation(self, xp: ModuleType):
+        self._test_input_validation(xp)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 77ba8cd8..74ad3a19 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -182,11 +182,14 @@ def test_eager_shape(xp: ModuleType, library: Backend):
     # Lazy arrays, like Dask, have an eager shape until you slice them with
     # a lazy boolean mask
     assert eager_shape(a) == a.shape == (3,)
+    assert eager_shape(a, axis=0) == a.shape == (3,)
 
     b = a[a > 2]
     if library is Backend.DASK:
         with pytest.raises(TypeError, match="Unsupported lazy shape"):
             _ = eager_shape(b)
+        with pytest.raises(TypeError, match="Unsupported lazy shape"):
+            _ = eager_shape(b, axis=0)
     # FIXME can't test use case for None in the shape until we add support for
     # other lazy backends
     else: