Skip to content

Commit 0bcbf80

Browse files
authored
Merge pull request #1 from ev-br/pr/78
2 parents 22d4fc0 + e0a8d7a commit 0bcbf80

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

array_api_strict/_searching_functions.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,29 @@ def searchsorted(
9090
# x1 must be 1-D, but NumPy already requires this.
9191
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
9292

93-
def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array:
93+
def where(
94+
condition: Array,
95+
x1: bool | int | float | complex | Array,
96+
x2: bool | int | float | complex | Array, /
97+
) -> Array:
9498
"""
9599
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
96100
97101
See its docstring for more information.
98102
"""
99103
if get_array_api_strict_flags()['api_version'] > '2023.12':
100-
if isinstance(x1, (bool, float, int)):
104+
num_scalars = 0
105+
106+
if isinstance(x1, (bool, float, complex, int)):
101107
x1 = Array._new(np.asarray(x1), device=condition.device)
108+
num_scalars += 1
102109

103-
if isinstance(x2, (bool, float, int)):
110+
if isinstance(x2, (bool, float, complex, int)):
104111
x2 = Array._new(np.asarray(x2), device=condition.device)
112+
num_scalars += 1
113+
114+
if num_scalars == 2:
115+
raise ValueError("One of x1, x2 arguments must be an array.")
105116

106117
# Call result type here just to raise on disallowed type combinations
107118
_result_type(x1.dtype, x2.dtype)

array_api_strict/tests/test_searching_functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ def test_where_with_scalars():
2020
),
2121
ArrayAPIStrictFlags(api_version=draft_version),
2222
):
23-
x_where = xp.where(x == 1, 42, 44)
23+
x_where = xp.where(x == 1, xp.asarray(42), 44)
2424

2525
expected = xp.asarray([42, 44, 44, 42])
2626
assert xp.all(x_where == expected)
27+
28+
# The spec does not allow both x1 and x2 to be scalars
29+
with pytest.raises(ValueError, match="One of"):
30+
xp.where(x == 1, 42, 44)

0 commit comments

Comments
 (0)