Skip to content

Commit 0aa20f5

Browse files
committed
2024.12 scalars
1 parent 8966397 commit 0aa20f5

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
161161
temp1 = f1(*(arr[cond] for arr in args))
162162

163163
if f2 is None:
164-
# TODO remove asarrays once all backends support Array API 2024.12
165-
dtype = xp.result_type(*asarrays(temp1, fill_value, xp=xp))
164+
dtype = xp.result_type(temp1, fill_value)
166165
if getattr(fill_value, "ndim", 0):
167166
out = xp.astype(fill_value, dtype, copy=True)
168167
else:

tests/test_funcs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,11 @@ def f2(*args: Array) -> Array:
256256

257257
ref1 = xp.where(cond, f1(*arrays), fill_value)
258258
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
259-
# TODO remove asarrays once all backends support Array API 2024.12
260-
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
259+
if library is Backend.ARRAY_API_STRICT:
260+
# FIXME https://github.com/data-apis/array-api-strict/issues/131
261+
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
262+
else:
263+
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
261264

262265
xp_assert_close(res1, ref1, rtol=2e-16)
263266
xp_assert_equal(res2, ref2)

0 commit comments

Comments
 (0)