diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 2e512fc8..3858b9aa 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,7 @@ from __future__ import annotations from builtins import bool as py_bool +from typing import Literal import cupy as cp @@ -139,6 +140,24 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) +def searchsorted( + x1: Array, + x2: Array | int | float, + /, + *, + side: Literal['left', 'right'] = 'left', + sorter: Array | None = None +) -> Array: + # Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum + # supported version + if not isinstance(x2, cp.ndarray): + if not isinstance(x2, int | float | complex): + raise NotImplementedError( + 'Only python scalars or ndarrays are supported for x2') + x2 = cp.asarray(x2) + return cp.searchsorted(x1, x2, side, sorter) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -161,7 +180,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'ceil', 'floor', 'trunc', 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis', + 'searchsorted', +] def __dir__() -> list[str]: