Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from ._utils import _compat
from ._utils._compat import (
array_namespace,
is_array_api_obj,
is_array_api_strict_namespace,
is_dask_array,
is_jax_array,
is_torch_array,
Expand Down Expand Up @@ -148,6 +150,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
Array([125], dtype=int32)

For frameworks that don't support fancy indexing by default, e.g. array-api-strict,
we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments
with multiple occurences of the same index always choose the last occurence. This is
consistent with numpy's behaviour, e.g.::

>>> import numpy as np
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
array([3])
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
Array([3], dtype=array_api_strict.int64)
Comment on lines +159 to +165
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example leaves me confused. I don't think it adds anything?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see

>>> xpx.at(torch.tensor([0]).cuda(), torch.tensor([0, 0]).cuda()).set(torch.tensor([2, 3]).cuda())
torch.Tensor([2], dtype=torch.int64)


See Also
--------
jax.numpy.ndarray.at : Equivalent array method in JAX.
Expand Down Expand Up @@ -355,9 +370,47 @@ def _op(
# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These remain broken.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm not sure if we should attempt to fix them. See my general comment.

else: # set()
return x
# set()
try: # We first try to use the backend's __setitem__ if available
x[idx] = y
return x
return x
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
Comment on lines +379 to +381
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.

Suggested change
except IndexError as e:
if "Fancy indexing" not in str(e): # Avoid masking other index errors
raise e
except Exception as e:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I thought about this as well. However, I am strongly opposed to a blank except. We would mask errors for regular frameworks that would subsequently enter an unexpected code path which may throw obscure errors. Hence the commend on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.

# Work around lack of fancy indexing __setitem__
if (
is_array_api_obj(idx)
and xp.isdtype(idx.dtype, "integral")
and idx.ndim == 1
):
# Vectorize the operation using boolean indexing
# For non-unique indices, take the last occurrence. This requires
# masks for x and y that create matching shapes.
# We first create the mask for x
u_idx, _ = xp.unique_inverse(idx)
# Convert negative indices to positive, otherwise they won't get matched
u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx)
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1)
# If y is a scalar or 0D, we are done
if not is_array_api_obj(y) or y.ndim == 0:
x[x_mask] = y
return x
# If not, create a mask for y. Get last occurrence of each unique index
cmp = u_idx[:, None] == u_idx[None, :]
# Ignore later matches
lower_tri_mask = (
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
)
masked_cmp = cmp & lower_tri_mask
# For each position i, count how many matches occurred before i
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
# Last occurrence has highest match count
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
# Apply the operation only to last occurrences
x[x_mask] = y[y_mask]
return x
raise e

def set(
self,
Expand Down
50 changes: 49 additions & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._compat import (
array_namespace,
is_jax_namespace,
is_writeable_array,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -272,6 +276,50 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


def test_setitem_int_array_index(xp: ModuleType):
# Single dimension
x = xp.asarray([0.0, 1.0, 2.0])
y = xp.asarray([3.0, 4.0])
idx = xp.asarray([0, 2])
expect = xp.asarray([3.0, 1.0, 4.0])
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Single dimension, non-unique index
x = xp.asarray([0.0, 1.0])
y = xp.asarray([2.0, 3.0])
idx = xp.asarray([1, 1])
device_str = str(get_device(x)).lower()
# GPU arrays generally use the first element, but JAX with float64 enabled uses the
# last element.
if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp):
expect = xp.asarray([0.0, 2.0])
else:
expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Multiple dimensions
x = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
y = xp.asarray([[4.0, 5.0]])
idx = xp.asarray([0])
expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
z = at_op(x, idx, _AtOp.SET, y)
xp_assert_equal(z, expect)
# Scalar
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([1]), _AtOp.SET, 2.0)
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
# 0D array
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0))
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
# Negative indices
x = xp.asarray([0.0, 1.0])
z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0)
xp_assert_equal(z, xp.asarray([0.0, 2.0]))


@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
Expand Down