Skip to content

Commit 6ebcdd9

Browse files
committed
Allow scalar arguments to where()
1 parent 6d205d7 commit 6ebcdd9

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/array_api_stubs/_draft/searching_functions.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
22

33

4-
from ._types import Optional, Tuple, Literal, array
4+
from ._types import Optional, Tuple, Literal, Union, array
55

66

77
def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
@@ -139,21 +139,24 @@ def searchsorted(
139139
"""
140140

141141

142-
def where(condition: array, x1: array, x2: array, /) -> array:
142+
def where(condition: array, x1: Union[array, int, float, bool], x2: Union[array, int, float, bool], /) -> array:
143143
"""
144144
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
145145
146146
Parameters
147147
----------
148148
condition: array
149-
when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`).
150-
x1: array
151-
first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
152-
x2: array
153-
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
149+
when ``True``, yield ``x1_i`` (scalar ``x1``); otherwise, yield ``x2_i`` (scalar ``x2``). Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`).
150+
x1: Union[array, int, float, bool]
151+
first input array or scalar. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
152+
x2: Union[array, int, float, bool]
153+
second input array or scalar. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
154154
155155
Returns
156156
-------
157157
out: array
158158
an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``.
159+
160+
.. versionchanged:: 2024.12
161+
``x1`` and ``x2`` may be scalars.
159162
"""

0 commit comments

Comments
 (0)