|
11 | 11 | from array_api_extra._lib._at import _AtOp
|
12 | 12 | from array_api_extra._lib._backends import Backend
|
13 | 13 | from array_api_extra._lib._testing import xp_assert_equal
|
14 |
| -from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array |
| 14 | +from array_api_extra._lib._utils._compat import ( |
| 15 | + array_namespace, |
| 16 | + is_jax_namespace, |
| 17 | + is_writeable_array, |
| 18 | +) |
15 | 19 | from array_api_extra._lib._utils._compat import device as get_device
|
16 | 20 | from array_api_extra._lib._utils._typing import Array, Device, SetIndex
|
17 | 21 | from array_api_extra.testing import lazy_xp_function
|
@@ -272,6 +276,38 @@ def test_bool_mask_nd(xp: ModuleType):
|
272 | 276 | xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
|
273 | 277 |
|
274 | 278 |
|
| 279 | +def test_setitem_int_array_index(xp: ModuleType): |
| 280 | + # Single dimension |
| 281 | + x = xp.asarray([0.0, 1.0, 2.0]) |
| 282 | + y = xp.asarray([3.0, 4.0]) |
| 283 | + idx = xp.asarray([0, 2]) |
| 284 | + expect = xp.asarray([3.0, 1.0, 4.0]) |
| 285 | + z = at_op(x, idx, _AtOp.SET, y) |
| 286 | + assert isinstance(z, type(x)) |
| 287 | + xp_assert_equal(z, expect) |
| 288 | + # Single dimension, non-unique index |
| 289 | + x = xp.asarray([0.0, 1.0]) |
| 290 | + y = xp.asarray([2.0, 3.0]) |
| 291 | + idx = xp.asarray([1, 1]) |
| 292 | + device_str = str(get_device(x)).lower() |
| 293 | + # GPU arrays generally use the first element, but JAX with float64 enabled uses the |
| 294 | + # last element. |
| 295 | + if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp): |
| 296 | + expect = xp.asarray([0.0, 2.0]) |
| 297 | + else: |
| 298 | + expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last |
| 299 | + z = at_op(x, idx, _AtOp.SET, y) |
| 300 | + assert isinstance(z, type(x)) |
| 301 | + xp_assert_equal(z, expect) |
| 302 | + # Multiple dimensions |
| 303 | + x = xp.asarray([[0.0, 1.0], [2.0, 3.0]]) |
| 304 | + y = xp.asarray([[4.0, 5.0]]) |
| 305 | + idx = xp.asarray([0]) |
| 306 | + expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]]) |
| 307 | + z = at_op(x, idx, _AtOp.SET, y) |
| 308 | + xp_assert_equal(z, expect) |
| 309 | + |
| 310 | + |
275 | 311 | @pytest.mark.parametrize("bool_mask", [False, True])
|
276 | 312 | def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
|
277 | 313 | x = xp.asarray([math.inf, 1.0, 2.0])
|
|
0 commit comments