Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from ._utils import _compat
from ._utils._compat import (
array_namespace,
is_array_api_obj,
is_array_api_strict_namespace,
is_dask_array,
is_jax_array,
is_torch_array,
Expand Down Expand Up @@ -352,6 +354,33 @@ def _op(
if is_torch_array(y):
y = xp.astype(y, x.dtype, copy=False)

# Work around lack of fancy indexing __setitem__ support in array-api-strict.
if (
is_array_api_strict_namespace(xp)
and is_array_api_obj(idx)
and xp.isdtype(idx.dtype, "integral")
and out_of_place_op is None # only use for set()
):
# Vectorize the operation using boolean indexing
# For non-unique indices, take the last occurrence. This requires creating
# masks for x and y that create matching shapes.
unique_indices, _ = xp.unique_inverse(idx)
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1)
# Get last occurrence of each unique index
cmp = unique_indices[:, None] == unique_indices[None, :]
# Ignore later matches
lower_tri_mask = (
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
)
masked_cmp = cmp & lower_tri_mask
# For each position i, count how many matches occurred before i
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
# Last occurrence has highest match count
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
# Apply the operation only to last occurrences
x[x_mask] = y[y_mask]
return x

# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These remain broken.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm not sure if we should attempt to fix them. See my general comment.

Expand Down
38 changes: 37 additions & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._compat import (
array_namespace,
is_jax_namespace,
is_writeable_array,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -272,6 +276,38 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


def test_setitem_int_array_index(xp: ModuleType):
# Single dimension
x = xp.asarray([0.0, 1.0, 2.0])
y = xp.asarray([3.0, 4.0])
idx = xp.asarray([0, 2])
expect = xp.asarray([3.0, 1.0, 4.0])
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Single dimension, non-unique index
x = xp.asarray([0.0, 1.0])
y = xp.asarray([2.0, 3.0])
idx = xp.asarray([1, 1])
device_str = str(get_device(x)).lower()
# GPU arrays generally use the first element, but JAX with float64 enabled uses the
# last element.
if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp):
expect = xp.asarray([0.0, 2.0])
else:
expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last
z = at_op(x, idx, _AtOp.SET, y)
assert isinstance(z, type(x))
xp_assert_equal(z, expect)
# Multiple dimensions
x = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
y = xp.asarray([[4.0, 5.0]])
idx = xp.asarray([0])
expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
z = at_op(x, idx, _AtOp.SET, y)
xp_assert_equal(z, expect)


@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
Expand Down