Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from ._utils import _compat
from ._utils._compat import (
array_namespace,
is_array_api_obj,
is_dask_array,
is_jax_array,
is_lazy_array,
is_torch_array,
is_writeable_array,
)
Expand Down Expand Up @@ -148,6 +150,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
Array([125], dtype=int32)

For frameworks that don't support fancy indexing by default, e.g. array-api-strict,
we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments
with multiple occurrences of the same index always choose the last occurrence. This
is consistent with numpy's behaviour, e.g.::

>>> import numpy as np
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
array([3])
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
Array([3], dtype=array_api_strict.int64)
Comment on lines +159 to +165
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example leaves me confused. I don't think it adds anything?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see

>>> xpx.at(torch.tensor([0]).cuda(), torch.tensor([0, 0]).cuda()).set(torch.tensor([2, 3]).cuda())
torch.Tensor([2], dtype=torch.int64)


See Also
--------
jax.numpy.ndarray.at : Equivalent array method in JAX.
Expand Down Expand Up @@ -355,9 +370,70 @@ def _op(
# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These remain broken.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm not sure if we should attempt to fix them. See my general comment.

else: # set()
return x
# set()
try: # We first try to use the backend's __setitem__ if available
x[idx] = y
return x
return x
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
Comment on lines +379 to +381
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.

Suggested change
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
except Exception as e:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I thought about this as well. However, I am strongly opposed to a blank except. We would mask errors for regular frameworks that would subsequently enter an unexpected code path which may throw obscure errors. Hence the comment on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.

# Work around lack of fancy indexing __setitem__
if (
is_array_api_obj(idx)
and xp.isdtype(idx.dtype, "integral")
and idx.ndim == 1
):
# Vectorize the operation using boolean indexing
# For non-unique indices, take the last occurrence. This requires
# masks for x and y that create matching shapes.
# We first create the mask for x
u_idx, _ = xp.unique_inverse(idx)
# Convert negative indices to positive, otherwise they won't get matched
u_idx_pos = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx)
# Check for out of bounds indices
oob_index = None
if is_lazy_array(u_idx_pos):
pass # Lazy arrays cannot check for out of bounds indices
elif xp.any(pos_oob := u_idx_pos >= x.shape[0]):
first_oob_idx = xp.argmax(xp.astype(pos_oob, xp.int32))
oob_index = int(u_idx[first_oob_idx])
elif xp.any(neg_oob := u_idx_pos < 0):
first_oob_idx = xp.argmax(xp.astype(neg_oob, xp.int32))
oob_index = int(u_idx[first_oob_idx])
if oob_index is not None:
msg = (
f"index {oob_index} is out of bounds for array of "
f"shape {x.shape}"
)
raise IndexError(msg) from e

x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
x_rng = xp.arange(x.shape[0], device=device(u_idx_pos))
x_mask = xp.any(x_rng[..., None] == u_idx_pos, axis=-1)

Could you add a comment explaining what you're doing here?

Copy link
Contributor

@crusaderky crusaderky Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to add in the documentation above a note warning that the implementation is quadratic.

If x is 10 MiB along axis 0 and the u_idx_pos is 10 MiB, this line transitorily consumes 100 terabytes of RAM.
Have you considered using searchsorted?

# If y is a scalar or 0D, we are done
if not is_array_api_obj(y) or y.ndim == 0:
x[x_mask] = y
return x
if y.shape[0] != idx.shape[0]:
msg = (
f"shape mismatch: value array of shape {y.shape} could not be "
f"broadcast to indexing result of shape {idx.shape}"
)
raise ValueError(msg) from e
# If not, create a mask for y. Get last occurrence of each unique index
cmp = u_idx_pos[:, None] == u_idx_pos[None, :]
# Ignore later matches
lower_tri_mask = (
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
)
masked_cmp = cmp & lower_tri_mask
# For each position i, count how many matches occurred before i
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
# Last occurrence has highest match count
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
# Apply the operation only to last occurrences
x[x_mask] = y[y_mask]
return x
raise e

def set(
self,
Expand Down
66 changes: 65 additions & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._compat import (
array_namespace,
is_array_api_strict_namespace,
is_jax_namespace,
is_numpy_namespace,
is_writeable_array,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -272,6 +278,64 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


def test_setitem_int_array_index(xp: ModuleType):
    """Check ``_AtOp.SET`` with 1D integer-array indices on the ``xp`` backend."""
    # 1D target, every index distinct
    target = xp.asarray([0.0, 1.0, 2.0])
    values = xp.asarray([3.0, 4.0])
    index = xp.asarray([0, 2])
    wanted = xp.asarray([3.0, 1.0, 4.0])
    result = at_op(target, index, _AtOp.SET, values)
    assert isinstance(result, type(target))
    xp_assert_equal(result, wanted)

    # 1D target, the same index given twice
    target = xp.asarray([0.0, 1.0])
    values = xp.asarray([2.0, 3.0])
    index = xp.asarray([1, 1])
    device_name = str(get_device(target)).lower()
    on_accelerator = any(tag in device_name for tag in ("gpu", "cuda"))
    # GPU arrays generally keep the first of the duplicated writes; JAX with
    # float64 enabled is the exception and keeps the last one. CPU arrays keep
    # the last one as well.
    if on_accelerator and not is_jax_namespace(xp):
        surviving_value = 2.0
    else:
        surviving_value = 3.0
    wanted = xp.asarray([0.0, surviving_value])
    result = at_op(target, index, _AtOp.SET, values)
    assert isinstance(result, type(target))
    xp_assert_equal(result, wanted)

    # 2D target: the integer index selects whole rows
    target = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
    values = xp.asarray([[4.0, 5.0]])
    index = xp.asarray([0])
    wanted = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
    result = at_op(target, index, _AtOp.SET, values)
    xp_assert_equal(result, wanted)

    # Python scalar as the value
    target = xp.asarray([0.0, 1.0])
    result = at_op(target, xp.asarray([1]), _AtOp.SET, 2.0)
    xp_assert_equal(result, xp.asarray([0.0, 2.0]))

    # 0D array as the value
    target = xp.asarray([0.0, 1.0])
    result = at_op(target, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0))
    xp_assert_equal(result, xp.asarray([0.0, 2.0]))

    # Negative index
    target = xp.asarray([0.0, 1.0])
    result = at_op(target, xp.asarray([-1]), _AtOp.SET, 2.0)
    xp_assert_equal(result, xp.asarray([0.0, 2.0]))

    # Backends disagree on negative indices, out-of-bounds indices and the
    # like, so the error paths are only checked for two of them: numpy, which
    # the docs name as the reference behaviour for backends without native
    # support, and array-api-strict.
    if not (is_array_api_strict_namespace(xp) or is_numpy_namespace(xp)):
        return

    # (expected exception, message pattern, index values, assigned values)
    failing_cases = (
        (ValueError, "shape", [0], [1, 2]),  # wrong value shape
        (IndexError, "out of bounds", [1], [1]),  # positive out of bounds
        (IndexError, "out of bounds", [-2], [1]),  # negative out of bounds
    )
    for exc_type, pattern, bad_index, bad_values in failing_cases:
        with pytest.raises(exc_type, match=pattern):
            _ = at_op(
                xp.asarray([0]),
                xp.asarray(bad_index),
                _AtOp.SET,
                xp.asarray(bad_values),
            )


@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
Expand Down