Skip to content

Commit 094577d

Browse files
committed
at().get()
1 parent 6c07b4f commit 094577d

File tree

5 files changed

+160
-18
lines changed

5 files changed

+160
-18
lines changed

src/array_api_extra/_funcs.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from ._lib import _utils
1212
from ._lib._compat import (
1313
array_namespace,
14+
is_array_api_obj,
15+
is_dask_array,
1416
is_jax_array,
17+
is_pydata_sparse_array,
1518
is_writeable_array,
1619
)
1720

@@ -686,6 +689,73 @@ def __getitem__(self, idx: Index, /) -> at:
686689
self._idx = idx
687690
return self
688691

692+
def _check_args(self, /, copy: bool | None) -> None:
693+
if self._idx is _undef:
694+
msg = (
695+
"Index has not been set.\n"
696+
"Usage: either\n"
697+
" at(x, idx).set(value)\n"
698+
"or\n"
699+
" at(x)[idx].set(value)\n"
700+
"(same for all other methods)."
701+
)
702+
raise TypeError(msg)
703+
704+
if copy not in (True, False, None):
705+
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
706+
raise ValueError(msg)
707+
708+
def get(
709+
self,
710+
/,
711+
copy: bool | None = True,
712+
xp: ModuleType | None = None,
713+
) -> Array:
714+
"""Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``,
715+
this allows ensuring that the output is either a copy or a view
716+
"""
717+
self._check_args(copy=copy)
718+
x = self._x
719+
720+
if copy is False:
721+
if is_array_api_obj(self._idx):
722+
# Boolean index. Note that the array API spec
723+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
724+
# does not allow for list, tuple, and tuples of slices plus one or more
725+
# one-dimensional array indices, although many backends support them.
726+
# So this check will encounter a lot of false negatives in real life,
727+
# which can be caught by testing the user code vs. array-api-strict.
728+
msg = "get() with an array index always returns a copy"
729+
raise ValueError(msg)
730+
731+
# Prevent scalar indices together with copy=False.
732+
# Even if some backends may return a scalar view of the original, we chose to be
733+
# strict here beceause some other backends, such as numpy, definitely don't.
734+
tup_idx = self._idx if isinstance(self._idx, tuple) else (self._idx,)
735+
if any(
736+
i is not None and i is not Ellipsis and not isinstance(i, slice)
737+
for i in tup_idx
738+
):
739+
msg = "get() with a scalar index typically returns a copy"
740+
raise ValueError(msg)
741+
742+
# Note: this is not the same list of backends as is_writeable_array()
743+
if is_dask_array(x) or is_jax_array(x) or is_pydata_sparse_array(x):
744+
msg = f"get() on {array_namespace(x)} arrays always returns a copy"
745+
raise ValueError(msg)
746+
747+
if is_jax_array(x):
748+
# Use JAX's at[] or other library that with the same duck-type API
749+
return x.at[self._idx].get()
750+
751+
if xp is None:
752+
xp = array_namespace(x)
753+
# Note: when idx is a boolean mask, numpy always returns a deep copy.
754+
# However, some backends may legitimately return a view when the mask can
755+
# be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
756+
# Err on the side of caution and perform a double-copy in numpy.
757+
return xp.asarray(x[self._idx], copy=copy)
758+
689759
def _update_common(
690760
self,
691761
at_op: str,
@@ -701,23 +771,7 @@ def _update_common(
701771
If the operation can be resolved by at[], (return value, None)
702772
Otherwise, (None, preprocessed x)
703773
"""
704-
x, idx = self._x, self._idx
705-
706-
if idx is _undef:
707-
msg = (
708-
"Index has not been set.\n"
709-
"Usage: either\n"
710-
" at(x, idx).set(value)\n"
711-
"or\n"
712-
" at(x)[idx].set(value)\n"
713-
"(same for all other methods)."
714-
)
715-
raise TypeError(msg)
716-
717-
if copy not in (True, False, None):
718-
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
719-
raise ValueError(msg)
720-
774+
x = self._x
721775
if copy is None:
722776
writeable = is_writeable_array(x)
723777
copy = not writeable
@@ -758,6 +812,7 @@ def set(
758812
xp: ModuleType | None = None,
759813
) -> Array:
760814
"""Apply ``x[idx] = y`` and return the update array"""
815+
self._check_args(copy=copy)
761816
res, x = self._update_common("set", y, copy=copy, xp=xp)
762817
if res is not None:
763818
return res
@@ -785,6 +840,7 @@ def _iop(
785840
Consider for example when x is a numpy array and idx is a fancy index, which
786841
triggers a deep copy on __getitem__.
787842
"""
843+
self._check_args(copy=copy)
788844
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
789845
if res is not None:
790846
return res

src/array_api_extra/_lib/_compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,29 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace,
88
device,
9+
is_array_api_obj,
10+
is_dask_array,
911
is_jax_array,
12+
is_pydata_sparse_array,
1013
is_writeable_array,
1114
)
1215
except ImportError:
1316
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1417
array_namespace,
1518
device,
19+
is_array_api_obj,
20+
is_dask_array,
1621
is_jax_array,
22+
is_pydata_sparse_array,
1723
is_writeable_array,
1824
)
1925

2026
__all__ = (
2127
"array_namespace",
2228
"device",
29+
"is_array_api_obj",
30+
"is_dask_array",
2331
"is_jax_array",
32+
"is_pydata_sparse_array",
2433
"is_writeable_array",
2534
)

src/array_api_extra/_lib/_compat.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,8 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
1416
def is_jax_array(x: object, /) -> bool: ...
17+
def is_pydata_sparse_array(x: object, /) -> bool: ...
1518
def is_writeable_array(x: object, /) -> bool: ...

tests/test_at.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from contextlib import contextmanager
3+
from contextlib import contextmanager, suppress
44
from importlib import import_module
55
from typing import TYPE_CHECKING, Final
66

@@ -9,6 +9,7 @@
99
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
1010
array_namespace,
1111
is_dask_array,
12+
is_numpy_array,
1213
is_pydata_sparse_array,
1314
is_writeable_array,
1415
)
@@ -98,6 +99,71 @@ def test_update_ops(
9899
assert_array_equal(y, expect)
99100

100101

102+
@pytest.mark.parametrize("copy", [True, False, None])
103+
def test_get(array: Array, copy: bool | None):
104+
expect_copy = copy
105+
106+
# dask is mutable, but __getitem__ never returns a view
107+
if is_dask_array(array):
108+
if copy is False:
109+
with pytest.raises(ValueError, match="always returns a copy"):
110+
at(array, slice(2)).get(copy=False)
111+
return
112+
expect_copy = True
113+
114+
# get(copy=False) on a read-only numpy array returns a read-only view
115+
if is_numpy_array(array) and not copy and not array.flags.writeable:
116+
out = at(array, slice(2)).get(copy=copy)
117+
assert_array_equal(out, [10.0, 20.0])
118+
assert out.base is array
119+
assert not out.flags.writeable
120+
return
121+
122+
with assert_copy(array, expect_copy):
123+
y = at(array, slice(2)).get(copy=copy)
124+
assert isinstance(y, type(array))
125+
assert_array_equal(y, [10.0, 20.0])
126+
# Let assert_copy test that y is a view or copy
127+
with suppress(TypeError, ValueError):
128+
y[:] = 40
129+
130+
131+
def test_get_scalar_nocopy(array: Array):
132+
"""get(copy=False) with a scalar index always raises, because some backends
133+
such as numpy and sparse return a np.generic instead of a scalar view
134+
"""
135+
with pytest.raises(ValueError, match="scalar"):
136+
at(array)[0].get(copy=False)
137+
with pytest.raises(ValueError, match="scalar"):
138+
at(array)[(0,)].get(copy=False)
139+
with pytest.raises(ValueError, match="scalar"):
140+
at(array)[..., 0].get(copy=False)
141+
142+
143+
def test_get_bool_indices(array: Array):
144+
"""get() with a boolean array index always returns a copy"""
145+
# sparse violates the array API as it doesn't support
146+
# a boolean index that is another sparse array.
147+
# dask with dask index has NaN size, which complicates testing.
148+
if is_pydata_sparse_array(array) or is_dask_array(array):
149+
xp = np
150+
else:
151+
xp = array_namespace(array)
152+
idx = xp.asarray([True, False, True])
153+
154+
with pytest.raises(ValueError, match="copy"):
155+
at(array, idx).get(copy=False)
156+
157+
assert_array_equal(at(array, idx).get(), [10.0, 30.0])
158+
159+
with assert_copy(array, True):
160+
y = at(array, idx).get(copy=True)
161+
assert_array_equal(y, [10.0, 30.0])
162+
# Let assert_copy test that y is a view or copy
163+
with suppress(TypeError, ValueError):
164+
y[:] = 40
165+
166+
101167
def test_copy_invalid():
102168
a = np.asarray([1, 2, 3])
103169
with pytest.raises(ValueError, match="copy"):
@@ -106,6 +172,7 @@ def test_copy_invalid():
106172

107173
def test_xp():
108174
a = np.asarray([1, 2, 3])
175+
at(a, 0).get(xp=np)
109176
at(a, 0).set(4, xp=np)
110177
at(a, 0).add(4, xp=np)
111178
at(a, 0).subtract(4, xp=np)

vendor_tests/test_vendor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,21 @@ def test_vendor_compat():
88
from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
99
array_namespace,
1010
device,
11+
is_array_api_obj,
12+
is_dask_array,
1113
is_jax_array,
14+
is_pydata_sparse_array,
1215
is_writeable_array,
1316
)
1417

1518
x = xp.asarray([1, 2, 3])
1619
assert array_namespace(x) is xp
1720
device(x)
21+
assert is_array_api_obj(x)
22+
assert not is_array_api_obj(123)
23+
assert not is_dask_array(x)
1824
assert not is_jax_array(x)
25+
assert not is_pydata_sparse_array(x)
1926
assert is_writeable_array(x)
2027

2128

0 commit comments

Comments
 (0)