Skip to content

Commit 6c07b4f

Browse files
committed
Remove get()
1 parent 1e38e00 commit 6c07b4f

File tree

5 files changed

+18
-160
lines changed

5 files changed

+18
-160
lines changed

src/array_api_extra/_funcs.py

Lines changed: 17 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
from ._lib import _utils
1212
from ._lib._compat import (
1313
array_namespace,
14-
is_array_api_obj,
15-
is_dask_array,
1614
is_jax_array,
17-
is_pydata_sparse_array,
1815
is_writeable_array,
1916
)
2017

@@ -689,73 +686,6 @@ def __getitem__(self, idx: Index, /) -> at:
689686
self._idx = idx
690687
return self
691688

692-
def _check_args(self, /, copy: bool | None) -> None:
693-
if self._idx is _undef:
694-
msg = (
695-
"Index has not been set.\n"
696-
"Usage: either\n"
697-
" at(x, idx).set(value)\n"
698-
"or\n"
699-
" at(x)[idx].set(value)\n"
700-
"(same for all other methods)."
701-
)
702-
raise TypeError(msg)
703-
704-
if copy not in (True, False, None):
705-
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
706-
raise ValueError(msg)
707-
708-
def get(
709-
self,
710-
/,
711-
copy: bool | None = True,
712-
xp: ModuleType | None = None,
713-
) -> Array:
714-
"""Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``,
715-
this allows ensuring that the output is either a copy or a view
716-
"""
717-
self._check_args(copy=copy)
718-
x = self._x
719-
720-
if copy is False:
721-
if is_array_api_obj(self._idx):
722-
# Boolean index. Note that the array API spec
723-
# https://data-apis.org/array-api/latest/API_specification/indexing.html
724-
# does not allow for list, tuple, and tuples of slices plus one or more
725-
# one-dimensional array indices, although many backends support them.
726-
# So this check will encounter a lot of false negatives in real life,
727-
# which can be caught by testing the user code vs. array-api-strict.
728-
msg = "get() with an array index always returns a copy"
729-
raise ValueError(msg)
730-
731-
# Prevent scalar indices together with copy=False.
732-
# Even if some backends may return a scalar view of the original, we chose to be
733-
# strict here beceause some other backends, such as numpy, definitely don't.
734-
tup_idx = self._idx if isinstance(self._idx, tuple) else (self._idx,)
735-
if any(
736-
i is not None and i is not Ellipsis and not isinstance(i, slice)
737-
for i in tup_idx
738-
):
739-
msg = "get() with a scalar index typically returns a copy"
740-
raise ValueError(msg)
741-
742-
# Note: this is not the same list of backends as is_writeable_array()
743-
if is_dask_array(x) or is_jax_array(x) or is_pydata_sparse_array(x):
744-
msg = f"get() on {array_namespace(x)} arrays always returns a copy"
745-
raise ValueError(msg)
746-
747-
if is_jax_array(x):
748-
# Use JAX's at[] or other library that with the same duck-type API
749-
return x.at[self._idx].get()
750-
751-
if xp is None:
752-
xp = array_namespace(x)
753-
# Note: when idx is a boolean mask, numpy always returns a deep copy.
754-
# However, some backends may legitimately return a view when the mask can
755-
# be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
756-
# Err on the side of caution and perform a double-copy in numpy.
757-
return xp.asarray(x[self._idx], copy=copy)
758-
759689
def _update_common(
760690
self,
761691
at_op: str,
@@ -771,7 +701,23 @@ def _update_common(
771701
If the operation can be resolved by at[], (return value, None)
772702
Otherwise, (None, preprocessed x)
773703
"""
774-
x = self._x
704+
x, idx = self._x, self._idx
705+
706+
if idx is _undef:
707+
msg = (
708+
"Index has not been set.\n"
709+
"Usage: either\n"
710+
" at(x, idx).set(value)\n"
711+
"or\n"
712+
" at(x)[idx].set(value)\n"
713+
"(same for all other methods)."
714+
)
715+
raise TypeError(msg)
716+
717+
if copy not in (True, False, None):
718+
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
719+
raise ValueError(msg)
720+
775721
if copy is None:
776722
writeable = is_writeable_array(x)
777723
copy = not writeable
@@ -812,7 +758,6 @@ def set(
812758
xp: ModuleType | None = None,
813759
) -> Array:
814760
"""Apply ``x[idx] = y`` and return the update array"""
815-
self._check_args(copy=copy)
816761
res, x = self._update_common("set", y, copy=copy, xp=xp)
817762
if res is not None:
818763
return res
@@ -840,7 +785,6 @@ def _iop(
840785
Consider for example when x is a numpy array and idx is a fancy index, which
841786
triggers a deep copy on __getitem__.
842787
"""
843-
self._check_args(copy=copy)
844788
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
845789
if res is not None:
846790
return res

src/array_api_extra/_lib/_compat.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,20 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace,
88
device,
9-
is_array_api_obj,
10-
is_dask_array,
119
is_jax_array,
12-
is_pydata_sparse_array,
1310
is_writeable_array,
1411
)
1512
except ImportError:
1613
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1714
array_namespace,
1815
device,
19-
is_array_api_obj,
20-
is_dask_array,
2116
is_jax_array,
22-
is_pydata_sparse_array,
2317
is_writeable_array,
2418
)
2519

2620
__all__ = (
2721
"array_namespace",
2822
"device",
29-
"is_array_api_obj",
30-
"is_dask_array",
3123
"is_jax_array",
32-
"is_pydata_sparse_array",
3324
"is_writeable_array",
3425
)

src/array_api_extra/_lib/_compat.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,5 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14-
def is_array_api_obj(x: object, /) -> bool: ...
15-
def is_dask_array(x: object, /) -> bool: ...
1614
def is_jax_array(x: object, /) -> bool: ...
17-
def is_pydata_sparse_array(x: object, /) -> bool: ...
1815
def is_writeable_array(x: object, /) -> bool: ...

tests/test_at.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from contextlib import contextmanager, suppress
3+
from contextlib import contextmanager
44
from importlib import import_module
55
from typing import TYPE_CHECKING, Final
66

@@ -9,7 +9,6 @@
99
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
1010
array_namespace,
1111
is_dask_array,
12-
is_numpy_array,
1312
is_pydata_sparse_array,
1413
is_writeable_array,
1514
)
@@ -99,71 +98,6 @@ def test_update_ops(
9998
assert_array_equal(y, expect)
10099

101100

102-
@pytest.mark.parametrize("copy", [True, False, None])
103-
def test_get(array: Array, copy: bool | None):
104-
expect_copy = copy
105-
106-
# dask is mutable, but __getitem__ never returns a view
107-
if is_dask_array(array):
108-
if copy is False:
109-
with pytest.raises(ValueError, match="always returns a copy"):
110-
at(array, slice(2)).get(copy=False)
111-
return
112-
expect_copy = True
113-
114-
# get(copy=False) on a read-only numpy array returns a read-only view
115-
if is_numpy_array(array) and not copy and not array.flags.writeable:
116-
out = at(array, slice(2)).get(copy=copy)
117-
assert_array_equal(out, [10.0, 20.0])
118-
assert out.base is array
119-
assert not out.flags.writeable
120-
return
121-
122-
with assert_copy(array, expect_copy):
123-
y = at(array, slice(2)).get(copy=copy)
124-
assert isinstance(y, type(array))
125-
assert_array_equal(y, [10.0, 20.0])
126-
# Let assert_copy test that y is a view or copy
127-
with suppress(TypeError, ValueError):
128-
y[:] = 40
129-
130-
131-
def test_get_scalar_nocopy(array: Array):
132-
"""get(copy=False) with a scalar index always raises, because some backends
133-
such as numpy and sparse return a np.generic instead of a scalar view
134-
"""
135-
with pytest.raises(ValueError, match="scalar"):
136-
at(array)[0].get(copy=False)
137-
with pytest.raises(ValueError, match="scalar"):
138-
at(array)[(0,)].get(copy=False)
139-
with pytest.raises(ValueError, match="scalar"):
140-
at(array)[..., 0].get(copy=False)
141-
142-
143-
def test_get_bool_indices(array: Array):
144-
"""get() with a boolean array index always returns a copy"""
145-
# sparse violates the array API as it doesn't support
146-
# a boolean index that is another sparse array.
147-
# dask with dask index has NaN size, which complicates testing.
148-
if is_pydata_sparse_array(array) or is_dask_array(array):
149-
xp = np
150-
else:
151-
xp = array_namespace(array)
152-
idx = xp.asarray([True, False, True])
153-
154-
with pytest.raises(ValueError, match="copy"):
155-
at(array, idx).get(copy=False)
156-
157-
assert_array_equal(at(array, idx).get(), [10.0, 30.0])
158-
159-
with assert_copy(array, True):
160-
y = at(array, idx).get(copy=True)
161-
assert_array_equal(y, [10.0, 30.0])
162-
# Let assert_copy test that y is a view or copy
163-
with suppress(TypeError, ValueError):
164-
y[:] = 40
165-
166-
167101
def test_copy_invalid():
168102
a = np.asarray([1, 2, 3])
169103
with pytest.raises(ValueError, match="copy"):
@@ -172,7 +106,6 @@ def test_copy_invalid():
172106

173107
def test_xp():
174108
a = np.asarray([1, 2, 3])
175-
at(a, 0).get(xp=np)
176109
at(a, 0).set(4, xp=np)
177110
at(a, 0).add(4, xp=np)
178111
at(a, 0).subtract(4, xp=np)

vendor_tests/test_vendor.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,14 @@ def test_vendor_compat():
88
from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
99
array_namespace,
1010
device,
11-
is_array_api_obj,
12-
is_dask_array,
1311
is_jax_array,
14-
is_pydata_sparse_array,
1512
is_writeable_array,
1613
)
1714

1815
x = xp.asarray([1, 2, 3])
1916
assert array_namespace(x) is xp
2017
device(x)
21-
assert is_array_api_obj(x)
22-
assert not is_array_api_obj(123)
23-
assert not is_dask_array(x)
2418
assert not is_jax_array(x)
25-
assert not is_pydata_sparse_array(x)
2619
assert is_writeable_array(x)
2720

2821

0 commit comments

Comments
 (0)