Skip to content

Commit a990846

Browse files
committed
rework incompatible_dtype test
1 parent eb6b721 commit a990846

File tree

1 file changed

+43
-48
lines changed

1 file changed

+43
-48
lines changed

tests/test_at.py

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
from array_api_extra._lib._utils._typing import Array, Index
1717
from array_api_extra.testing import lazy_xp_function
1818

19+
pytestmark = [
20+
pytest.mark.skip_xp_backend(
21+
Backend.SPARSE, reason="read-only backend without .at support"
22+
)
23+
]
24+
1925

2026
def at_op( # type: ignore[no-any-explicit]
2127
x: Array,
@@ -71,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7177
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
7278

7379

74-
@pytest.mark.skip_xp_backend(
75-
Backend.SPARSE, reason="read-only backend without .at support"
76-
)
7780
@pytest.mark.parametrize(
7881
("kwargs", "expect_copy"),
7982
[
@@ -170,78 +173,70 @@ def test_alternate_index_syntax():
170173
at(a, 0)[0].set(4)
171174

172175

173-
@pytest.mark.skip_xp_backend(
174-
Backend.SPARSE, reason="read-only backend without .at support"
175-
)
176176
@pytest.mark.parametrize("copy", [True, None])
177-
@pytest.mark.parametrize(
178-
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
179-
)
180-
def test_iops_incompatible_dtype(
181-
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None
177+
@pytest.mark.parametrize("bool_mask", [False, True])
178+
@pytest.mark.parametrize("op", list(_AtOp))
179+
def test_incompatible_dtype(
180+
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
182181
):
183182
"""Test that at() replicates the backend's behaviour for
184183
in-place operations with incompatible dtypes.
185184
186-
Note:
185+
Behavior is backend-specific, but only two behaviors are allowed:
186+
1. raise an exception, or
187+
2. return the same dtype as x, disregarding y.dtype (no broadcasting).
188+
189+
Note that __i<op>__ and __<op>__ behave differently, and we want to
190+
replicate the behavior of __i<op>__:
191+
187192
>>> a = np.asarray([1, 2, 3])
188193
>>> a / 1.5
189194
array([0. , 0.66666667, 1.33333333])
190195
>>> a /= 1.5
191196
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
192197
to dtype('int64') with casting rule 'same_kind'
198+
199+
See Also
200+
--------
193201
"""
194202
x = xp.asarray([2, 4])
195-
196-
if library is Backend.DASK:
197-
z = at_op(x, slice(None), op, 1.1, copy=copy)
198-
assert z.dtype == x.dtype
199-
200-
elif library is Backend.JAX:
201-
with pytest.warns(FutureWarning, match="cannot safely cast"):
202-
z = at_op(x, slice(None), op, 1.1, copy=copy)
203-
assert z.dtype == x.dtype
204-
205-
else:
203+
idx = xp.asarray([True, False]) if bool_mask else slice(None)
204+
z = None
205+
206+
if library is Backend.JAX:
207+
if bool_mask:
208+
z = at_op(x, idx, op, 1.1, copy=copy)
209+
else:
210+
with pytest.warns(FutureWarning, match="cannot safely cast"):
211+
z = at_op(x, idx, op, 1.1, copy=copy)
212+
213+
elif library is Backend.DASK:
214+
if op in (_AtOp.MIN, _AtOp.MAX):
215+
pytest.xfail(reason="need array-api-compat 1.11")
216+
z = at_op(x, idx, op, 1.1, copy=copy)
217+
218+
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
206219
with pytest.raises(Exception, match=r"cast|promote|dtype"):
207-
at_op(x, slice(None), op, 1.1, copy=copy)
220+
at_op(x, idx, op, 1.1, copy=copy)
208221

209-
210-
@pytest.mark.skip_xp_backend(
211-
Backend.SPARSE, reason="read-only backend without .at support"
212-
)
213-
@pytest.mark.parametrize(
214-
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
215-
)
216-
def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp):
217-
"""
218-
When xp.where(idx, y, x) would promote the dtype of the output
219-
to y.dtype, at(x, idx).set(y) must retain x.dtype instead
220-
"""
221-
x = xp.asarray([1, 2])
222-
idx = xp.asarray([True, False])
223-
if library in (Backend.DASK, Backend.JAX):
224-
z = at_op(x, idx, op, 1.1)
225-
assert z.dtype == x.dtype
222+
elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX):
223+
# There is no __i<op>__ version of these operations
224+
z = at_op(x, idx, op, 1.1, copy=copy)
226225

227226
else:
228227
with pytest.raises(Exception, match=r"cast|promote|dtype"):
229-
at_op(x, idx, op, 1.1)
228+
at_op(x, idx, op, 1.1, copy=copy)
229+
230+
assert z is None or z.dtype == x.dtype
230231

231232

232-
@pytest.mark.skip_xp_backend(
233-
Backend.SPARSE, reason="read-only backend without .at support"
234-
)
235233
def test_bool_mask_nd(xp: ModuleType):
236234
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
237235
idx = xp.asarray([[True, False, False], [False, True, True]])
238236
z = at_op(x, idx, _AtOp.SET, 0)
239237
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
240238

241239

242-
@pytest.mark.skip_xp_backend(
243-
Backend.SPARSE, reason="read-only backend without .at support"
244-
)
245240
@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
246241
@pytest.mark.parametrize("bool_mask", [False, True])
247242
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):

0 commit comments

Comments
 (0)