|
1 | 1 | __all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
|
2 | 2 |
|
3 | 3 |
|
4 |
| -from ._types import Optional, Tuple, Literal, array |
| 4 | +from ._types import Optional, Tuple, Literal, Union, array |
5 | 5 |
|
6 | 6 |
|
7 | 7 | def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
|
@@ -139,21 +139,24 @@ def searchsorted(
|
139 | 139 | """
|
140 | 140 |
|
141 | 141 |
|
142 |
| -def where(condition: array, x1: array, x2: array, /) -> array: |
| 142 | +def where(condition: array, x1: Union[array, int, float, bool], x2: Union[array, int, float, bool], /) -> array: |
143 | 143 | """
|
144 | 144 | Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
|
145 | 145 |
|
146 | 146 | Parameters
|
147 | 147 | ----------
|
148 | 148 | condition: array
|
149 |
| - when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). |
150 |
| - x1: array |
151 |
| - first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). |
152 |
| - x2: array |
153 |
| - second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). |
| 149 | + when ``True``, yield ``x1_i`` (scalar ``x1``); otherwise, yield ``x2_i`` (scalar ``x2``). Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`). |
| 150 | + x1: Union[array, int, float, bool] |
| 151 | + first input array or scalar. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`). |
| 152 | + x2: Union[array, int, float, bool] |
| 153 | + second input array or scalar. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`). |
154 | 154 |
|
155 | 155 | Returns
|
156 | 156 | -------
|
157 | 157 | out: array
|
158 | 158 | an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``.
|
| 159 | +
|
| 160 | + .. versionchanged:: 2024.12 |
| 161 | + ``x1`` and ``x2`` may be scalars. |
159 | 162 | """
|
0 commit comments