Skip to content

Commit 6fe2ffb

Browse files
committed
Add handling and tests for negative and positive out-of-bounds indices
1 parent 2c5e0aa commit 6fe2ffb

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from ._utils._compat import (
1313
array_namespace,
1414
is_array_api_obj,
15-
is_array_api_strict_namespace,
1615
is_dask_array,
1716
is_jax_array,
17+
is_lazy_array,
1818
is_torch_array,
1919
is_writeable_array,
2020
)
@@ -390,14 +390,37 @@ def _op(
390390
# We first create the mask for x
391391
u_idx, _ = xp.unique_inverse(idx)
392392
# Convert negative indices to positive, otherwise they won't get matched
393-
u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx)
394-
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1)
393+
u_idx_pos = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx)
394+
# Check for out of bounds indices
395+
oob_index = None
396+
if is_lazy_array(u_idx_pos):
397+
pass # Lazy arrays cannot check for out of bounds indices
398+
elif xp.any(pos_oob := u_idx_pos >= x.shape[0]):
399+
first_oob_idx = xp.argmax(xp.astype(pos_oob, xp.int32))
400+
oob_index = int(u_idx[first_oob_idx])
401+
elif xp.any(neg_oob := u_idx_pos < 0):
402+
first_oob_idx = xp.argmax(xp.astype(neg_oob, xp.int32))
403+
oob_index = int(u_idx[first_oob_idx])
404+
if oob_index is not None:
405+
msg = (
406+
f"index {oob_index} is out of bounds for array of "
407+
f"shape {x.shape}"
408+
)
409+
raise IndexError(msg) from e
410+
411+
x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
395412
# If y is a scalar or 0D, we are done
396413
if not is_array_api_obj(y) or y.ndim == 0:
397414
x[x_mask] = y
398415
return x
416+
if y.shape[0] != idx.shape[0]:
417+
msg = (
418+
f"shape mismatch: value array of shape {y.shape} could not be "
419+
f"broadcast to indexing result of shape {idx.shape}"
420+
)
421+
raise ValueError(msg) from e
399422
# If not, create a mask for y. Get last occurrence of each unique index
400-
cmp = u_idx[:, None] == u_idx[None, :]
423+
cmp = u_idx_pos[:, None] == u_idx_pos[None, :]
401424
# Ignore later matches
402425
lower_tri_mask = (
403426
xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]

tests/test_at.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from array_api_extra._lib._testing import xp_assert_equal
1414
from array_api_extra._lib._utils._compat import (
1515
array_namespace,
16+
is_array_api_strict_namespace,
1617
is_jax_namespace,
18+
is_numpy_namespace,
1719
is_writeable_array,
1820
)
1921
from array_api_extra._lib._utils._compat import device as get_device
@@ -318,6 +320,20 @@ def test_setitem_int_array_index(xp: ModuleType):
318320
x = xp.asarray([0.0, 1.0])
319321
z = at_op(x, xp.asarray([-1]), _AtOp.SET, 2.0)
320322
xp_assert_equal(z, xp.asarray([0.0, 2.0]))
323+
# Different frameworks have all kinds of different behaviours for negative indices,
324+
# out-of-bounds indices, etc. Therefore, we only test the behaviour of two
325+
# frameworks: numpy because we state in the docs that it is our reference for the
326+
# behaviour of other frameworks with no native support, and array-api-strict.
327+
if is_array_api_strict_namespace(xp) or is_numpy_namespace(xp):
328+
# Test wrong shapes
329+
with pytest.raises(ValueError, match="shape"):
330+
at_op(xp.asarray([0]), xp.asarray([0]), _AtOp.SET, xp.asarray([1, 2]))
331+
# Test positive out of bounds index
332+
with pytest.raises(IndexError, match="out of bounds"):
333+
at_op(xp.asarray([0]), xp.asarray([1]), _AtOp.SET, xp.asarray([1]))
334+
# Test negative out of bounds index
335+
with pytest.raises(IndexError, match="out of bounds"):
336+
at_op(xp.asarray([0]), xp.asarray([-2]), _AtOp.SET, xp.asarray([1]))
321337

322338

323339
@pytest.mark.parametrize("bool_mask", [False, True])

0 commit comments

Comments
 (0)