|
16 | 16 | from array_api_extra._lib._utils._typing import Array, Index |
17 | 17 | from array_api_extra.testing import lazy_xp_function |
18 | 18 |
|
| 19 | +pytestmark = [ |
| 20 | + pytest.mark.skip_xp_backend( |
| 21 | + Backend.SPARSE, reason="read-only backend without .at support" |
| 22 | + ) |
| 23 | +] |
| 24 | + |
19 | 25 |
|
20 | 26 | def at_op( # type: ignore[no-any-explicit] |
21 | 27 | x: Array, |
@@ -71,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: |
71 | 77 | xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) |
72 | 78 |
|
73 | 79 |
|
74 | | -@pytest.mark.skip_xp_backend( |
75 | | - Backend.SPARSE, reason="read-only backend without .at support" |
76 | | -) |
77 | 80 | @pytest.mark.parametrize( |
78 | 81 | ("kwargs", "expect_copy"), |
79 | 82 | [ |
@@ -170,78 +173,70 @@ def test_alternate_index_syntax(): |
170 | 173 | at(a, 0)[0].set(4) |
171 | 174 |
|
172 | 175 |
|
173 | | -@pytest.mark.skip_xp_backend( |
174 | | - Backend.SPARSE, reason="read-only backend without .at support" |
175 | | -) |
176 | 176 | @pytest.mark.parametrize("copy", [True, None]) |
177 | | -@pytest.mark.parametrize( |
178 | | - "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] |
179 | | -) |
180 | | -def test_iops_incompatible_dtype( |
181 | | - xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None |
| 177 | +@pytest.mark.parametrize("bool_mask", [False, True]) |
| 178 | +@pytest.mark.parametrize("op", list(_AtOp)) |
| 179 | +def test_incompatible_dtype( |
| 180 | + xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool |
182 | 181 | ): |
183 | 182 | """Test that at() replicates the backend's behaviour for |
184 | 183 | in-place operations with incompatible dtypes. |
185 | 184 |
|
186 | | - Note: |
| 185 | + Behavior is backend-specific, but only two behaviors are allowed: |
| 186 | + 1. raise an exception, or |
| 187 | + 2. return the same dtype as x, disregarding y.dtype (no broadcasting). |
| 188 | +
|
| 189 | + Note that __i<op>__ and __<op>__ behave differently, and we want to |
| 190 | + replicate the behavior of __i<op>__: |
| 191 | +
|
187 | 192 | >>> a = np.asarray([1, 2, 3]) |
188 | 193 | >>> a / 1.5 |
189 | 194 | array([0. , 0.66666667, 1.33333333]) |
190 | 195 | >>> a /= 1.5 |
191 | 196 | UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') |
192 | 197 | to dtype('int64') with casting rule 'same_kind' |
| 198 | +
|
| 199 | + See Also |
| 200 | + -------- |
193 | 201 | """ |
194 | 202 | x = xp.asarray([2, 4]) |
195 | | - |
196 | | - if library is Backend.DASK: |
197 | | - z = at_op(x, slice(None), op, 1.1, copy=copy) |
198 | | - assert z.dtype == x.dtype |
199 | | - |
200 | | - elif library is Backend.JAX: |
201 | | - with pytest.warns(FutureWarning, match="cannot safely cast"): |
202 | | - z = at_op(x, slice(None), op, 1.1, copy=copy) |
203 | | - assert z.dtype == x.dtype |
204 | | - |
205 | | - else: |
| 203 | + idx = xp.asarray([True, False]) if bool_mask else slice(None) |
| 204 | + z = None |
| 205 | + |
| 206 | + if library is Backend.JAX: |
| 207 | + if bool_mask: |
| 208 | + z = at_op(x, idx, op, 1.1, copy=copy) |
| 209 | + else: |
| 210 | + with pytest.warns(FutureWarning, match="cannot safely cast"): |
| 211 | + z = at_op(x, idx, op, 1.1, copy=copy) |
| 212 | + |
| 213 | + elif library is Backend.DASK: |
| 214 | + if op in (_AtOp.MIN, _AtOp.MAX): |
| 215 | + pytest.xfail(reason="need array-api-compat 1.11") |
| 216 | + z = at_op(x, idx, op, 1.1, copy=copy) |
| 217 | + |
| 218 | + elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET: |
206 | 219 | with pytest.raises(Exception, match=r"cast|promote|dtype"): |
207 | | - at_op(x, slice(None), op, 1.1, copy=copy) |
| 220 | + at_op(x, idx, op, 1.1, copy=copy) |
208 | 221 |
|
209 | | - |
210 | | -@pytest.mark.skip_xp_backend( |
211 | | - Backend.SPARSE, reason="read-only backend without .at support" |
212 | | -) |
213 | | -@pytest.mark.parametrize( |
214 | | - "op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER] |
215 | | -) |
216 | | -def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp): |
217 | | - """ |
218 | | - When xp.where(idx, y, x) would promote the dtype of the output |
219 | | - to y.dtype, at(x, idx).set(y) must retain x.dtype instead |
220 | | - """ |
221 | | - x = xp.asarray([1, 2]) |
222 | | - idx = xp.asarray([True, False]) |
223 | | - if library in (Backend.DASK, Backend.JAX): |
224 | | - z = at_op(x, idx, op, 1.1) |
225 | | - assert z.dtype == x.dtype |
| 222 | + elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX): |
| 223 | + # There is no __i<op>__ version of these operations |
| 224 | + z = at_op(x, idx, op, 1.1, copy=copy) |
226 | 225 |
|
227 | 226 | else: |
228 | 227 | with pytest.raises(Exception, match=r"cast|promote|dtype"): |
229 | | - at_op(x, idx, op, 1.1) |
| 228 | + at_op(x, idx, op, 1.1, copy=copy) |
| 229 | + |
| 230 | + assert z is None or z.dtype == x.dtype |
230 | 231 |
|
231 | 232 |
|
232 | | -@pytest.mark.skip_xp_backend( |
233 | | - Backend.SPARSE, reason="read-only backend without .at support" |
234 | | -) |
235 | 233 | def test_bool_mask_nd(xp: ModuleType): |
236 | 234 | x = xp.asarray([[1, 2, 3], [4, 5, 6]]) |
237 | 235 | idx = xp.asarray([[True, False, False], [False, True, True]]) |
238 | 236 | z = at_op(x, idx, _AtOp.SET, 0) |
239 | 237 | xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]])) |
240 | 238 |
|
241 | 239 |
|
242 | | -@pytest.mark.skip_xp_backend( |
243 | | - Backend.SPARSE, reason="read-only backend without .at support" |
244 | | -) |
245 | 240 | @pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere") |
246 | 241 | @pytest.mark.parametrize("bool_mask", [False, True]) |
247 | 242 | def test_no_inf_warnings(xp: ModuleType, bool_mask: bool): |
|
0 commit comments