Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions spec/draft/API_specification/type_promotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ Notes
.. note::
Mixed integer and floating-point type promotion rules are not specified because behavior varies between implementations.


.. _mixing-scalars-and-arrays:

Mixing arrays with Python scalars
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
24 changes: 18 additions & 6 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]


from ._types import Optional, Tuple, Literal, array
from ._types import Optional, Tuple, Literal, Union, array


def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -139,21 +139,33 @@ def searchsorted(
"""


def where(condition: array, x1: array, x2: array, /) -> array:
def where(
condition: array,
x1: Union[array, int, float, bool],
x2: Union[array, int, float, bool],
/,
) -> array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.

Parameters
----------
condition: array
when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`).
x1: array
first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
x2: array
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
x1: Union[array, int, float, complex, bool]
first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
x2: Union[array, int, float, complex, bool]
second input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).

Returns
-------
out: array
an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``.

Notes
-----
See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``.

.. versionchanged:: 2024.12
``x1`` and ``x2`` may be scalars.
"""