Skip to content

Commit f11c85a

Browse files
committed
Adapt to draft_version
1 parent d9a6b73 commit f11c85a

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

array_api_strict/_searching_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | flo
8989
_result_type(x1.dtype, x2.dtype)
9090

9191
if len({a.device for a in (condition, x1, x2)}) > 1:
92-
raise ValueError("where inputs must all be on the same device")
92+
raise ValueError("Inputs to `where` must all use the same device")
9393

9494
x1, x2 = Array._normalize_two_args(x1, x2)
9595
return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device)

array_api_strict/tests/test_searching_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import array_api_strict as xp
44

55
from array_api_strict import ArrayAPIStrictFlags
6-
from array_api_strict._flags import next_supported_version
6+
from array_api_strict._flags import draft_version
77

88

99
def test_where_with_scalars():
@@ -14,7 +14,7 @@ def test_where_with_scalars():
1414
xp.where(x == 1, 42, 44)
1515

1616
# Versions after 2023.12 support scalar arguments
17-
with ArrayAPIStrictFlags(api_version=next_supported_version):
17+
with ArrayAPIStrictFlags(api_version=draft_version):
1818
x_where = xp.where(x == 1, 42, 44)
1919

2020
expected = xp.asarray([42, 44, 44, 42])

0 commit comments

Comments
 (0)