From 7f42065b945f013833e1ce5d7aad20c1e97fed94 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Thu, 21 Aug 2025 02:49:47 +0200 Subject: [PATCH 1/6] Add fancy __setitem__ support for array-api-strict --- src/array_api_extra/_lib/_at.py | 29 +++++++++++++++++++++++++ tests/test_at.py | 38 ++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index fb2d6ab7..92e90a41 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -11,6 +11,8 @@ from ._utils import _compat from ._utils._compat import ( array_namespace, + is_array_api_obj, + is_array_api_strict_namespace, is_dask_array, is_jax_array, is_torch_array, @@ -352,6 +354,33 @@ def _op( if is_torch_array(y): y = xp.astype(y, x.dtype, copy=False) + # Work around lack of fancy indexing __setitem__ support in array-api-strict. + if ( + is_array_api_strict_namespace(xp) + and is_array_api_obj(idx) + and xp.isdtype(idx.dtype, "integral") + and out_of_place_op is None # only use for set() + ): + # Vectorize the operation using boolean indexing + # For non-unique indices, take the last occurrence. This requires creating + # masks for x and y that create matching shapes. 
+ unique_indices, first_occurrence_mask = xp.unique_inverse(idx) + x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1) + # Get last occurrence of each unique index + cmp = unique_indices[:, None] == unique_indices[None, :] + # Ignore later matches + lower_tri_mask = ( + xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :] + ) + masked_cmp = cmp & lower_tri_mask + # For each position i, count how many matches occurred before i + prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1) + # Last occurrence has highest match count + y_mask = prior_matches == xp.max(prior_matches, axis=-1) + # Apply the operation only to last occurrences + x[x_mask] = y[y_mask] + return x + # Backends without boolean indexing (other than JAX) crash here if in_place_op: # add(), subtract(), ... x[idx] = in_place_op(x[idx], y) diff --git a/tests/test_at.py b/tests/test_at.py index 9558f7b8..ec15788b 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -11,7 +11,11 @@ from array_api_extra._lib._at import _AtOp from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_equal -from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array +from array_api_extra._lib._utils._compat import ( + array_namespace, + is_jax_namespace, + is_writeable_array, +) from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._typing import Array, Device, SetIndex from array_api_extra.testing import lazy_xp_function @@ -272,6 +276,38 @@ def test_bool_mask_nd(xp: ModuleType): xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]])) +def test_setitem_int_array_index(xp: ModuleType): + # Single dimension + x = xp.asarray([0.0, 1.0, 2.0]) + y = xp.asarray([3.0, 4.0]) + idx = xp.asarray([0, 2]) + expect = xp.asarray([3.0, 1.0, 4.0]) + z = at_op(x, idx, _AtOp.SET, y) + assert isinstance(z, type(x)) + xp_assert_equal(z, expect) + # Single dimension, non-unique 
index + x = xp.asarray([0.0, 1.0]) + y = xp.asarray([2.0, 3.0]) + idx = xp.asarray([1, 1]) + device_str = str(get_device(x)).lower() + # GPU arrays generally use the first element, but JAX with float64 enabled uses the + # last element. + if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp): + expect = xp.asarray([0.0, 2.0]) + else: + expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last + z = at_op(x, idx, _AtOp.SET, y) + assert isinstance(z, type(x)) + xp_assert_equal(z, expect) + # Multiple dimensions + x = xp.asarray([[0.0, 1.0], [2.0, 3.0]]) + y = xp.asarray([[4.0, 5.0]]) + idx = xp.asarray([0]) + expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]]) + z = at_op(x, idx, _AtOp.SET, y) + xp_assert_equal(z, expect) + + @pytest.mark.parametrize("bool_mask", [False, True]) def test_no_inf_warnings(xp: ModuleType, bool_mask: bool): x = xp.asarray([math.inf, 1.0, 2.0]) From 9df1c6520cc390aa15635093de0863574c95c034 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Thu, 21 Aug 2025 11:21:26 +0200 Subject: [PATCH 2/6] Remove unnecessary variable assignment --- src/array_api_extra/_lib/_at.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 92e90a41..d79f552e 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -364,7 +364,7 @@ def _op( # Vectorize the operation using boolean indexing # For non-unique indices, take the last occurrence. This requires creating # masks for x and y that create matching shapes. 
- unique_indices, first_occurrence_mask = xp.unique_inverse(idx) + unique_indices, _ = xp.unique_inverse(idx) x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1) # Get last occurrence of each unique index cmp = unique_indices[:, None] == unique_indices[None, :] From 2c5e0aa8085b7a33b7608ff62da49d62388f90a1 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Sat, 23 Aug 2025 23:25:37 +0200 Subject: [PATCH 3/6] Add docs. Fix 0D and scalar cases. Handle negative indices --- src/array_api_extra/_lib/_at.py | 82 +++++++++++++++++++++------------ tests/test_at.py | 12 +++++ 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index d79f552e..a889551d 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -150,6 +150,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 >>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1) Array([125], dtype=int32) + For frameworks that don't support fancy indexing by default, e.g. array-api-strict, + we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments + with multiple occurences of the same index always choose the last occurence. This is + consistent with numpy's behaviour, e.g.:: + + >>> import numpy as np + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3])) + array([3]) + >>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3])) + Array([3], dtype=array_api_strict.int64) + See Also -------- jax.numpy.ndarray.at : Equivalent array method in JAX. @@ -354,39 +367,50 @@ def _op( if is_torch_array(y): y = xp.astype(y, x.dtype, copy=False) - # Work around lack of fancy indexing __setitem__ support in array-api-strict. 
- if ( - is_array_api_strict_namespace(xp) - and is_array_api_obj(idx) - and xp.isdtype(idx.dtype, "integral") - and out_of_place_op is None # only use for set() - ): - # Vectorize the operation using boolean indexing - # For non-unique indices, take the last occurrence. This requires creating - # masks for x and y that create matching shapes. - unique_indices, _ = xp.unique_inverse(idx) - x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1) - # Get last occurrence of each unique index - cmp = unique_indices[:, None] == unique_indices[None, :] - # Ignore later matches - lower_tri_mask = ( - xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :] - ) - masked_cmp = cmp & lower_tri_mask - # For each position i, count how many matches occurred before i - prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1) - # Last occurrence has highest match count - y_mask = prior_matches == xp.max(prior_matches, axis=-1) - # Apply the operation only to last occurrences - x[x_mask] = y[y_mask] - return x - # Backends without boolean indexing (other than JAX) crash here if in_place_op: # add(), subtract(), ... x[idx] = in_place_op(x[idx], y) - else: # set() + return x + # set() + try: # We first try to use the backend's __setitem__ if available x[idx] = y - return x + return x + except IndexError as e: + if "Fancy indexing" not in str(e): # Avoid masking other index errors + raise e + # Work around lack of fancy indexing __setitem__ + if ( + is_array_api_obj(idx) + and xp.isdtype(idx.dtype, "integral") + and idx.ndim == 1 + ): + # Vectorize the operation using boolean indexing + # For non-unique indices, take the last occurrence. This requires + # masks for x and y that create matching shapes. 
+ # We first create the mask for x + u_idx, _ = xp.unique_inverse(idx) + # Convert negative indices to positive, otherwise they won't get matched + u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) + x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1) + # If y is a scalar or 0D, we are done + if not is_array_api_obj(y) or y.ndim == 0: + x[x_mask] = y + return x + # If not, create a mask for y. Get last occurrence of each unique index + cmp = u_idx[:, None] == u_idx[None, :] + # Ignore later matches + lower_tri_mask = ( + xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :] + ) + masked_cmp = cmp & lower_tri_mask + # For each position i, count how many matches occurred before i + prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1) + # Last occurrence has highest match count + y_mask = prior_matches == xp.max(prior_matches, axis=-1) + # Apply the operation only to last occurrences + x[x_mask] = y[y_mask] + return x + raise e def set( self, diff --git a/tests/test_at.py b/tests/test_at.py index ec15788b..17899ad9 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -306,6 +306,18 @@ def test_setitem_int_array_index(xp: ModuleType): expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]]) z = at_op(x, idx, _AtOp.SET, y) xp_assert_equal(z, expect) + # Scalar + x = xp.asarray([0.0, 1.0]) + z = at_op(x, xp.asarray([1]), _AtOp.SET, 2.0) + xp_assert_equal(z, xp.asarray([0.0, 2.0])) + # 0D array + x = xp.asarray([0.0, 1.0]) + z = at_op(x, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0)) + xp_assert_equal(z, xp.asarray([0.0, 2.0])) + # Negative indices + x = xp.asarray([0.0, 1.0]) + z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0) + xp_assert_equal(z, xp.asarray([0.0, 2.0])) @pytest.mark.parametrize("bool_mask", [False, True]) From 6fe2ffb03cb6056a37db8c3849cc3b773faf78b0 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Sun, 24 Aug 2025 13:04:30 +0200 Subject: [PATCH 4/6] Add handling and tests for negative and positive out-of-bounds indices 
--- src/array_api_extra/_lib/_at.py | 31 +++++++++++++++++++++++++++---- tests/test_at.py | 16 ++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index a889551d..6ef685e1 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -12,9 +12,9 @@ from ._utils._compat import ( array_namespace, is_array_api_obj, - is_array_api_strict_namespace, is_dask_array, is_jax_array, + is_lazy_array, is_torch_array, is_writeable_array, ) @@ -390,14 +390,37 @@ def _op( # We first create the mask for x u_idx, _ = xp.unique_inverse(idx) # Convert negative indices to positive, otherwise they won't get matched - u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) - x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1) + u_idx_pos = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) + # Check for out of bounds indices + oob_index = None + if is_lazy_array(u_idx_pos): + pass # Lazy arrays cannot check for out of bounds indices + elif xp.any(pos_oob := u_idx_pos >= x.shape[0]): + first_oob_idx = xp.argmax(xp.astype(pos_oob, xp.int32)) + oob_index = int(u_idx[first_oob_idx]) + elif xp.any(neg_oob := u_idx_pos < 0): + first_oob_idx = xp.argmax(xp.astype(neg_oob, xp.int32)) + oob_index = int(u_idx[first_oob_idx]) + if oob_index is not None: + msg = ( + f"index {oob_index} is out of bounds for array of " + f"shape {x.shape}" + ) + raise IndexError(msg) from e + + x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1) # If y is a scalar or 0D, we are done if not is_array_api_obj(y) or y.ndim == 0: x[x_mask] = y return x + if y.shape[0] != idx.shape[0]: + msg = ( + f"shape mismatch: value array of shape {y.shape} could not be " + f"broadcast to indexing result of shape {idx.shape}" + ) + raise ValueError(msg) from e # If not, create a mask for y. 
Get last occurrence of each unique index - cmp = u_idx[:, None] == u_idx[None, :] + cmp = u_idx_pos[:, None] == u_idx_pos[None, :] # Ignore later matches lower_tri_mask = ( xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :] diff --git a/tests/test_at.py b/tests/test_at.py index 17899ad9..531787a3 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -13,7 +13,9 @@ from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import ( array_namespace, + is_array_api_strict_namespace, is_jax_namespace, + is_numpy_namespace, is_writeable_array, ) from array_api_extra._lib._utils._compat import device as get_device @@ -318,6 +320,20 @@ def test_setitem_int_array_index(xp: ModuleType): x = xp.asarray([0.0, 1.0]) z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0) xp_assert_equal(z, xp.asarray([0.0, 2.0])) + # Different frameworks have all kinds of different behaviours for negative indices, + # out-of-bounds indices, etc. Therefore, we only test the behaviour of two + # frameworks: numpy because we state in the docs that it is our reference for the + # behaviour of other frameworks with no native support, and array-api-strict. 
+ if is_array_api_strict_namespace(xp) or is_numpy_namespace(xp): + # Test wrong shapes + with pytest.raises(ValueError, match="shape"): + at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2])) + # Test positive out of bounds index + with pytest.raises(IndexError, match="out of bounds"): + at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1])) + # Test negative out of bounds index + with pytest.raises(IndexError, match="out of bounds"): + at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1])) @pytest.mark.parametrize("bool_mask", [False, True]) From acd536540cc8027835d2cbed57b4821868e6de9d Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Sun, 24 Aug 2025 16:30:37 +0200 Subject: [PATCH 5/6] Fix linting --- src/array_api_extra/_lib/_at.py | 4 ++-- tests/test_at.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 6ef685e1..77467f0e 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -152,8 +152,8 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 For frameworks that don't support fancy indexing by default, e.g. array-api-strict, we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments - with multiple occurences of the same index always choose the last occurence. This is - consistent with numpy's behaviour, e.g.:: + with multiple occurrences of the same index always choose the last occurrence. 
This + is consistent with numpy's behaviour, e.g.:: >>> import numpy as np >>> import array_api_strict as xp diff --git a/tests/test_at.py b/tests/test_at.py index 531787a3..da834383 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -327,13 +327,13 @@ def test_setitem_int_array_index(xp: ModuleType): if is_array_api_strict_namespace(xp) or is_numpy_namespace(xp): # Test wrong shapes with pytest.raises(ValueError, match="shape"): - at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2])) + _ = at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2])) # Test positive out of bounds index with pytest.raises(IndexError, match="out of bounds"): - at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1])) + _ = at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1])) # Test negative out of bounds index with pytest.raises(IndexError, match="out of bounds"): - at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1])) + _ = at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1])) @pytest.mark.parametrize("bool_mask", [False, True]) From 07e6bda45e5d4e1e0eb539d7880d168441af4e2c Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Thu, 28 Aug 2025 17:56:40 +0200 Subject: [PATCH 6/6] [wip] Fix _at logic. Update tests --- src/array_api_extra/_lib/_at.py | 104 +++++++++++++++----------------- tests/test_at.py | 10 +-- 2 files changed, 54 insertions(+), 60 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 77467f0e..e2ee4fa0 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -75,7 +75,8 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 idx : index, optional Only `array API standard compliant indices `_ - are supported. + are supported. The only exceptions are one-dimensional integer array indices + (not expressed as tuples) along the first axis for set() operations. 
You may use two alternate syntaxes:: @@ -150,10 +151,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 >>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1) Array([125], dtype=int32) - For frameworks that don't support fancy indexing by default, e.g. array-api-strict, - we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments - with multiple occurrences of the same index always choose the last occurrence. This - is consistent with numpy's behaviour, e.g.:: + The Array API standard does not support assignment by integer array, even if many + libraries like NumPy do. `xpx.at` works around lack of support by performing an + out-of-place operation. Assignments with multiple occurrences of the same index + always choose the last occurrence. This is consistent with NumPy's behaviour. >>> import numpy as np >>> import array_api_strict as xp @@ -379,61 +380,54 @@ def _op( if "Fancy indexing" not in str(e): # Avoid masking other index errors raise e # Work around lack of fancy indexing __setitem__ - if ( + if not ( is_array_api_obj(idx) and xp.isdtype(idx.dtype, "integral") and idx.ndim == 1 ): - # Vectorize the operation using boolean indexing - # For non-unique indices, take the last occurrence. This requires - # masks for x and y that create matching shapes. 
- # We first create the mask for x - u_idx, _ = xp.unique_inverse(idx) - # Convert negative indices to positive, otherwise they won't get matched - u_idx_pos = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) - # Check for out of bounds indices - oob_index = None - if is_lazy_array(u_idx_pos): - pass # Lazy arrays cannot check for out of bounds indices - elif xp.any(pos_oob := u_idx_pos >= x.shape[0]): - first_oob_idx = xp.argmax(xp.astype(pos_oob, xp.int32)) - oob_index = int(u_idx[first_oob_idx]) - elif xp.any(neg_oob := u_idx_pos < 0): - first_oob_idx = xp.argmax(xp.astype(neg_oob, xp.int32)) - oob_index = int(u_idx[first_oob_idx]) - if oob_index is not None: - msg = ( - f"index {oob_index} is out of bounds for array of " - f"shape {x.shape}" - ) - raise IndexError(msg) from e - - x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1) - # If y is a scalar or 0D, we are done - if not is_array_api_obj(y) or y.ndim == 0: - x[x_mask] = y - return x - if y.shape[0] != idx.shape[0]: - msg = ( - f"shape mismatch: value array of shape {y.shape} could not be " - f"broadcast to indexing result of shape {idx.shape}" - ) - raise ValueError(msg) from e - # If not, create a mask for y. Get last occurrence of each unique index - cmp = u_idx_pos[:, None] == u_idx_pos[None, :] - # Ignore later matches - lower_tri_mask = ( - xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :] - ) - masked_cmp = cmp & lower_tri_mask - # For each position i, count how many matches occurred before i - prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1) - # Last occurrence has highest match count - y_mask = prior_matches == xp.max(prior_matches, axis=-1) - # Apply the operation only to last occurrences - x[x_mask] = y[y_mask] + raise + # Vectorize the operation using boolean indexing + # For non-unique indices, take the last occurrence. This requires + # masks for x and y that create matching shapes. 
+ # We first create the mask for x + # Convert negative indices to positive, otherwise they won't get matched + idx = xp.where(idx < 0, x.shape[0] + idx, idx) + u_idx = xp.sort(xp.unique_values(idx)) + # Check for out of bounds indices + if not is_lazy_array(u_idx) and ( + xp.any(u_idx < 0) or xp.any(u_idx >= x.shape[0]) + ): + msg = f"index or indices out of bounds for array of shape {x.shape}" + raise IndexError(msg) from e + + # Construct a mask for x that is True where x's index is in u_idx. + # Equivalent to np.isin(). + x_rng = xp.arange(x.shape[0], device=_compat.device(u_idx)) + x_mask = xp.any(x_rng[..., None] == u_idx, axis=-1) + # If y is a scalar or 0D, we are done + if not is_array_api_obj(y) or y.ndim == 0: + x[x_mask] = y return x - raise e + if y.shape[0] != idx.shape[0]: + msg = ( + f"shape mismatch: value array of shape {y.shape} could not be " + f"broadcast to indexing result of shape {idx.shape}" + ) + raise ValueError(msg) from e + # If not, create a mask for y. Get last occurrence of each unique index + cmp = idx[:, None] == idx[None, :] + total_matches = xp.sum(xp.astype(cmp, xp.int32), axis=-1) + # Ignore later matches + n = idx.shape[0] + lower_tri_mask = xp.tril(xp.ones((n, n), dtype=xp.bool)) + masked_cmp = cmp & lower_tri_mask + # For each position i, count how many matches occurred before i + prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1) + # Last occurrence has highest match count + y_mask = prior_matches == total_matches + # Apply the operation only to last occurrences + x[x_mask] = y[y_mask] + return x def set( self, diff --git a/tests/test_at.py b/tests/test_at.py index da834383..1843f969 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -288,16 +288,16 @@ def test_setitem_int_array_index(xp: ModuleType): assert isinstance(z, type(x)) xp_assert_equal(z, expect) # Single dimension, non-unique index - x = xp.asarray([0.0, 1.0]) - y = xp.asarray([2.0, 3.0]) - idx = xp.asarray([1, 1]) + x = xp.asarray([0.0, 1.0, 
2.0]) + y = xp.asarray([3.0, 4.0, 5.0]) + idx = xp.asarray([0, 1, 0]) device_str = str(get_device(x)).lower() # GPU arrays generally use the first element, but JAX with float64 enabled uses the # last element. if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp): - expect = xp.asarray([0.0, 2.0]) + expect = xp.asarray([3.0, 4.0, 2.0]) else: - expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last + expect = xp.asarray([5.0, 4.0, 2.0]) # CPU arrays use the last z = at_op(x, idx, _AtOp.SET, y) assert isinstance(z, type(x)) xp_assert_equal(z, expect)