@@ -90,18 +90,29 @@ def searchsorted(
90
90
# x1 must be 1-D, but NumPy already requires this.
91
91
return Array ._new (np .searchsorted (x1 ._array , x2 ._array , side = side , sorter = sorter ), device = x1 .device )
92
92
93
- def where (condition : Array , x1 : bool | int | float | Array , x2 : bool | int | float | Array , / ) -> Array :
93
+ def where (
94
+ condition : Array ,
95
+ x1 : bool | int | float | complex | Array ,
96
+ x2 : bool | int | float | complex | Array , /
97
+ ) -> Array :
94
98
"""
95
99
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
96
100
97
101
See its docstring for more information.
98
102
"""
99
103
if get_array_api_strict_flags ()['api_version' ] > '2023.12' :
100
- if isinstance (x1 , (bool , float , int )):
104
+ num_scalars = 0
105
+
106
+ if isinstance (x1 , (bool , float , complex , int )):
101
107
x1 = Array ._new (np .asarray (x1 ), device = condition .device )
108
+ num_scalars += 1
102
109
103
- if isinstance (x2 , (bool , float , int )):
110
+ if isinstance (x2 , (bool , float , complex , int )):
104
111
x2 = Array ._new (np .asarray (x2 ), device = condition .device )
112
+ num_scalars += 1
113
+
114
+ if num_scalars == 2 :
115
+ raise ValueError ("One of x1, x2 arguments must be an array." )
105
116
106
117
# Call result type here just to raise on disallowed type combinations
107
118
_result_type (x1 .dtype , x2 .dtype )
0 commit comments