Skip to content

Commit 7f42065

Browse files
committed
Add fancy __setitem__ support for array-api-strict
1 parent cd3b49f commit 7f42065

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from ._utils import _compat
1212
from ._utils._compat import (
1313
array_namespace,
14+
is_array_api_obj,
15+
is_array_api_strict_namespace,
1416
is_dask_array,
1517
is_jax_array,
1618
is_torch_array,
@@ -352,6 +354,33 @@ def _op(
352354
if is_torch_array(y):
353355
y = xp.astype(y, x.dtype, copy=False)
354356

357+
# Work around lack of fancy indexing __setitem__ support in array-api-strict.
358+
if (
359+
is_array_api_strict_namespace(xp)
360+
and is_array_api_obj(idx)
361+
and xp.isdtype(idx.dtype, "integral")
362+
and out_of_place_op is None # only use for set()
363+
):
364+
# Vectorize the operation using boolean indexing
365+
# For non-unique indices, take the last occurrence. This requires creating
366+
# masks for x and y that create matching shapes.
367+
unique_indices, first_occurrence_mask = xp.unique_inverse(idx)
368+
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1)
369+
# Get last occurrence of each unique index
370+
cmp = unique_indices[:, None] == unique_indices[None, :]
371+
# Ignore later matches
372+
lower_tri_mask = (
373+
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
374+
)
375+
masked_cmp = cmp & lower_tri_mask
376+
# For each position i, count how many matches occurred before i
377+
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
378+
# Last occurrence has highest match count
379+
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
380+
# Apply the operation only to last occurrences
381+
x[x_mask] = y[y_mask]
382+
return x
383+
355384
# Backends without boolean indexing (other than JAX) crash here
356385
if in_place_op: # add(), subtract(), ...
357386
x[idx] = in_place_op(x[idx], y)

tests/test_at.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from array_api_extra._lib._at import _AtOp
1212
from array_api_extra._lib._backends import Backend
1313
from array_api_extra._lib._testing import xp_assert_equal
14-
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
14+
from array_api_extra._lib._utils._compat import (
15+
array_namespace,
16+
is_jax_namespace,
17+
is_writeable_array,
18+
)
1519
from array_api_extra._lib._utils._compat import device as get_device
1620
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
1721
from array_api_extra.testing import lazy_xp_function
@@ -272,6 +276,38 @@ def test_bool_mask_nd(xp: ModuleType):
272276
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
273277

274278

279+
def test_setitem_int_array_index(xp: ModuleType):
280+
# Single dimension
281+
x = xp.asarray([0.0, 1.0, 2.0])
282+
y = xp.asarray([3.0, 4.0])
283+
idx = xp.asarray([0, 2])
284+
expect = xp.asarray([3.0, 1.0, 4.0])
285+
z = at_op(x, idx, _AtOp.SET, y)
286+
assert isinstance(z, type(x))
287+
xp_assert_equal(z, expect)
288+
# Single dimension, non-unique index
289+
x = xp.asarray([0.0, 1.0])
290+
y = xp.asarray([2.0, 3.0])
291+
idx = xp.asarray([1, 1])
292+
device_str = str(get_device(x)).lower()
293+
# GPU arrays generally use the first element, but JAX with float64 enabled uses the
294+
# last element.
295+
if ("gpu" in device_str or "cuda" in device_str) and not is_jax_namespace(xp):
296+
expect = xp.asarray([0.0, 2.0])
297+
else:
298+
expect = xp.asarray([0.0, 3.0]) # CPU arrays use the last
299+
z = at_op(x, idx, _AtOp.SET, y)
300+
assert isinstance(z, type(x))
301+
xp_assert_equal(z, expect)
302+
# Multiple dimensions
303+
x = xp.asarray([[0.0, 1.0], [2.0, 3.0]])
304+
y = xp.asarray([[4.0, 5.0]])
305+
idx = xp.asarray([0])
306+
expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
307+
z = at_op(x, idx, _AtOp.SET, y)
308+
xp_assert_equal(z, expect)
309+
310+
275311
@pytest.mark.parametrize("bool_mask", [False, True])
276312
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
277313
x = xp.asarray([math.inf, 1.0, 2.0])

0 commit comments

Comments
 (0)