Skip to content

Commit 2c5e0aa

Browse files
committed
Add docs. Fix 0D and scalar cases. Handle negative indices
1 parent 9df1c65 commit 2c5e0aa

File tree

2 files changed

+65
-29
lines changed

2 files changed

+65
-29
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
150150
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
151151
Array([125], dtype=int32)
152152
153+
For frameworks that don't support fancy indexing by default, e.g. array-api-strict,
154+
we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments
155+
with multiple occurences of the same index always choose the last occurence. This is
156+
consistent with numpy's behaviour, e.g.::
157+
158+
>>> import numpy as np
159+
>>> import array_api_strict as xp
160+
>>> import array_api_extra as xpx
161+
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
162+
array([3])
163+
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
164+
Array([3], dtype=array_api_strict.int64)
165+
153166
See Also
154167
--------
155168
jax.numpy.ndarray.at : Equivalent array method in JAX.
@@ -354,39 +367,50 @@ def _op(
354367
if is_torch_array(y):
355368
y = xp.astype(y, x.dtype, copy=False)
356369

357-
# Work around lack of fancy indexing __setitem__ support in array-api-strict.
358-
if (
359-
is_array_api_strict_namespace(xp)
360-
and is_array_api_obj(idx)
361-
and xp.isdtype(idx.dtype, "integral")
362-
and out_of_place_op is None # only use for set()
363-
):
364-
# Vectorize the operation using boolean indexing
365-
# For non-unique indices, take the last occurrence. This requires creating
366-
# masks for x and y that create matching shapes.
367-
unique_indices, _ = xp.unique_inverse(idx)
368-
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == unique_indices, axis=-1)
369-
# Get last occurrence of each unique index
370-
cmp = unique_indices[:, None] == unique_indices[None, :]
371-
# Ignore later matches
372-
lower_tri_mask = (
373-
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
374-
)
375-
masked_cmp = cmp & lower_tri_mask
376-
# For each position i, count how many matches occurred before i
377-
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
378-
# Last occurrence has highest match count
379-
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
380-
# Apply the operation only to last occurrences
381-
x[x_mask] = y[y_mask]
382-
return x
383-
384370
# Backends without boolean indexing (other than JAX) crash here
385371
if in_place_op: # add(), subtract(), ...
386372
x[idx] = in_place_op(x[idx], y)
387-
else: # set()
373+
return x
374+
# set()
375+
try: # We first try to use the backend's __setitem__ if available
388376
x[idx] = y
389-
return x
377+
return x
378+
except IndexError as e:
379+
if "Fancy indexing" not in str(e): # Avoid masking other index errors
380+
raise e
381+
# Work around lack of fancy indexing __setitem__
382+
if (
383+
is_array_api_obj(idx)
384+
and xp.isdtype(idx.dtype, "integral")
385+
and idx.ndim == 1
386+
):
387+
# Vectorize the operation using boolean indexing
388+
# For non-unique indices, take the last occurrence. This requires
389+
# masks for x and y that create matching shapes.
390+
# We first create the mask for x
391+
u_idx, _ = xp.unique_inverse(idx)
392+
# Convert negative indices to positive, otherwise they won't get matched
393+
u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx)
394+
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1)
395+
# If y is a scalar or 0D, we are done
396+
if not is_array_api_obj(y) or y.ndim == 0:
397+
x[x_mask] = y
398+
return x
399+
# If not, create a mask for y. Get last occurrence of each unique index
400+
cmp = u_idx[:, None] == u_idx[None, :]
401+
# Ignore later matches
402+
lower_tri_mask = (
403+
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
404+
)
405+
masked_cmp = cmp & lower_tri_mask
406+
# For each position i, count how many matches occurred before i
407+
prior_matches = xp.sum(xp.astype(masked_cmp, xp.int32), axis=-1)
408+
# Last occurrence has highest match count
409+
y_mask = prior_matches == xp.max(prior_matches, axis=-1)
410+
# Apply the operation only to last occurrences
411+
x[x_mask] = y[y_mask]
412+
return x
413+
raise e
390414

391415
def set(
392416
self,

tests/test_at.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,18 @@ def test_setitem_int_array_index(xp: ModuleType):
306306
expect = xp.asarray([[4.0, 5.0], [2.0, 3.0]])
307307
z = at_op(x, idx, _AtOp.SET, y)
308308
xp_assert_equal(z, expect)
309+
# Scalar
310+
x = xp.asarray([0.0, 1.0])
311+
z = at_op(x, xp.asarray([1]), _AtOp.SET, 2.0)
312+
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
313+
# 0D array
314+
x = xp.asarray([0.0, 1.0])
315+
z = at_op(x, xp.asarray([1]), _AtOp.SET, xp.asarray(2.0))
316+
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
317+
# Negative indices
318+
x = xp.asarray([0.0, 1.0])
319+
z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0)
320+
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
309321

310322

311323
@pytest.mark.parametrize("bool_mask", [False, True])

0 commit comments

Comments
 (0)