Skip to content

Commit d0bbc88

Browse files
committed
Reuse alternate index syntax
1 parent c9d45af commit d0bbc88

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/array_api_extra/_funcs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,7 @@ def __getitem__(self, idx: Index, /) -> "at": # numpydoc ignore=PR01,RT01
672672
if self._idx is not _undef:
673673
msg = "Index has already been set"
674674
raise ValueError(msg)
675-
self._idx = idx
676-
return self
675+
return at(self._x, idx)
677676

678677
def _update_common(
679678
self,

tests/test_at.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def test_alternate_index_syntax():
127127
a = np.asarray([1, 2, 3])
128128
assert_array_equal(at(a, 0).set(4), [4, 2, 3])
129129
assert_array_equal(at(a)[0].set(4), [4, 2, 3])
130+
131+
a_at = at(a)
132+
assert_array_equal(a_at[0].add(1), [2, 2, 3])
133+
assert_array_equal(a_at[1].add(2), [1, 4, 3])
134+
130135
with pytest.raises(ValueError, match="Index"):
131136
at(a).set(4)
132137
with pytest.raises(ValueError, match="Index"):

0 commit comments

Comments
 (0)