Skip to content

Commit e428fa6

Browse files
committed
Unit tests WIP
1 parent 53a4ac9 commit e428fa6

File tree

3 files changed

+221
-31
lines changed

3 files changed

+221
-31
lines changed

array_api_compat/common/_helpers.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -804,17 +804,26 @@ def size(x):
804804
return math.prod(x.shape)
805805

806806

807-
def is_writeable_array(x):
807+
def is_writeable_array(x) -> bool:
808808
"""
809809
Return False if x.__setitem__ is expected to raise; True otherwise
810810
"""
811811
if is_numpy_array(x):
812812
return x.flags.writeable
813-
if is_jax_array(x):
813+
if is_jax_array(x) or is_pydata_sparse_array(x):
814814
return False
815815
return True
816816

817817

818+
def _is_fancy_index(idx) -> bool:
819+
if not isinstance(idx, tuple):
820+
idx = (idx,)
821+
return any(
822+
isinstance(i, (list, tuple)) or is_array_api_obj(i)
823+
for i in idx
824+
)
825+
826+
818827
_undef = object()
819828

820829

@@ -900,7 +909,7 @@ def __getitem__(self, idx):
900909
and feels more intuitive coming from the JAX documentation.
901910
"""
902911
if self.idx is not _undef:
903-
raise TypeError("Index has already been set")
912+
raise ValueError("Index has already been set")
904913
self.idx = idx
905914
return self
906915

@@ -911,14 +920,12 @@ def _common(
911920
copy: bool | None | Literal["_force_false"] = True,
912921
**kwargs,
913922
):
914-
"""Validate kwargs and perform common prepocessing.
923+
"""Perform common prepocessing.
915924
916925
Returns
917926
-------
918-
If the operation can be resolved by at[],
919-
(return value, None)
920-
Otherwise,
921-
(None, preprocessed x)
927+
If the operation can be resolved by at[], (return value, None)
928+
Otherwise, (None, preprocessed x)
922929
"""
923930
if self.idx is _undef:
924931
raise TypeError(
@@ -929,40 +936,44 @@ def _common(
929936
" at(x)[idx].set(value)\n"
930937
"(same for all other methods)."
931938
)
939+
940+
x = self.x
932941

933942
if copy is False:
934-
if not is_writeable_array(self.x):
935-
raise ValueError("Cannot avoid modifying parameter in place")
943+
if not is_writeable_array(x) or is_dask_array(x):
944+
raise ValueError("Cannot modify parameter in place")
936945
elif copy is None:
937-
copy = not is_writeable_array(self.x)
946+
copy = not is_writeable_array(x)
938947
elif copy == "_force_false":
939948
copy = False
940949
elif copy is not True:
941950
raise ValueError(f"Invalid value for copy: {copy!r}")
942951

943-
if copy and is_jax_array(self.x):
952+
if is_jax_array(x):
944953
# Use JAX's at[]
945-
at_ = self.x.at[self.idx]
954+
at_ = x.at[self.idx]
946955
args = (y,) if y is not _undef else ()
947956
return getattr(at_, at_op)(*args, **kwargs), None
948957

949958
# Emulate at[] behaviour for non-JAX arrays
950-
# FIXME We blindly expect the output of x.copy() to be always writeable.
951-
# This holds true for read-only numpy arrays, but not necessarily for
952-
# other backends.
953-
x = self.x.copy() if copy else self.x
959+
if copy:
960+
# FIXME We blindly expect the output of x.copy() to be always writeable.
961+
# This holds true for read-only numpy arrays, but not necessarily for
962+
# other backends.
963+
xp = get_namespace(x)
964+
x = xp.asarray(x, copy=True)
965+
954966
return None, x
955967

956968
def get(self, copy: bool | None = True, **kwargs):
957-
"""Return x[idx]. In addition to plain __getitem__, this allows ensuring
958-
that the output is (not) a copy and kwargs are passed to the backend."""
959-
# Special case when xp=numpy and idx is a fancy index
960-
# If copy is not False, avoid an unnecessary double copy.
961-
# if copy is forced to False, raise.
962-
if is_numpy_array(self.x) and (
963-
isinstance(self.idx, (list, tuple))
964-
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
965-
):
969+
"""
970+
Return x[idx]. In addition to plain __getitem__, this allows ensuring
971+
that the output is (not) a copy and kwargs are passed to the backend.
972+
"""
973+
# __getitem__ with a fancy index always returns a copy.
974+
# Avoid an unnecessary double copy.
975+
# If copy is forced to False, raise.
976+
if _is_fancy_index(self.idx):
966977
if copy is False:
967978
raise ValueError(
968979
"Indexing a numpy array with a fancy index always "
@@ -1032,13 +1043,15 @@ def power(self, y, /, **kwargs):
10321043

10331044
def min(self, y, /, **kwargs):
10341045
"""x[idx] = minimum(x[idx], y)"""
1035-
xp = array_namespace(self.x)
1036-
return self._iop("min", xp.minimum, y, **kwargs)
1046+
import numpy as np
1047+
1048+
return self._iop("min", np.minimum, y, **kwargs)
10371049

10381050
def max(self, y, /, **kwargs):
10391051
"""x[idx] = maximum(x[idx], y)"""
1040-
xp = array_namespace(self.x)
1041-
return self._iop("max", xp.maximum, y, **kwargs)
1052+
import numpy as np
1053+
1054+
return self._iop("max", np.maximum, y, **kwargs)
10421055

10431056

10441057
__all__ = [

tests/test_at.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager, suppress
4+
5+
import numpy as np
6+
import pytest
7+
8+
from array_api_compat import (
9+
array_namespace,
10+
at,
11+
is_dask_array,
12+
is_jax_array,
13+
is_torch_namespace,
14+
is_pydata_sparse_array,
15+
is_writeable_array,
16+
)
17+
from ._helpers import import_, all_libraries
18+
19+
20+
def assert_array_equal(a, b):
21+
if is_pydata_sparse_array(a):
22+
a = a.todense()
23+
elif is_dask_array(a):
24+
a = a.compute()
25+
np.testing.assert_array_equal(a, b)
26+
27+
28+
@contextmanager
29+
def assert_copy(x, copy: bool | None):
30+
# dask arrays are writeable, but writing to them will hot-swap the
31+
# dask graph inside the collection so that anything that references
32+
# the original graph, i.e. the input collection, won't be mutated.
33+
if copy is False and (not is_writeable_array(x) or is_dask_array(x)):
34+
with pytest.raises((TypeError, ValueError)):
35+
yield
36+
return
37+
38+
xp = array_namespace(x)
39+
x_orig = xp.asarray(x, copy=True)
40+
yield
41+
42+
if is_dask_array(x):
43+
expect_copy = True
44+
elif copy is None:
45+
expect_copy = not is_writeable_array(x)
46+
else:
47+
expect_copy = copy
48+
assert_array_equal((x == x_orig).all(), expect_copy)
49+
50+
51+
@pytest.fixture(params=all_libraries + ["np_readonly"])
52+
def x(request):
53+
library = request.param
54+
if library == "np_readonly":
55+
x = np.asarray([10, 20, 30])
56+
x.flags.writeable = False
57+
else:
58+
lib = import_(library)
59+
x = lib.asarray([10, 20, 30])
60+
return x
61+
62+
63+
@pytest.mark.parametrize("copy", [True, False, None])
64+
@pytest.mark.parametrize(
65+
"op,arg,expect",
66+
[
67+
("apply", np.negative, [10, -20, -30]),
68+
("set", 40, [10, 40, 40]),
69+
("add", 40, [10, 60, 70]),
70+
("subtract", 100, [10, -80, -70]),
71+
("multiply", 2, [10, 40, 60]),
72+
("divide", 3, [10, 6, 10]),
73+
("power", 2, [10, 400, 900]),
74+
("min", 25, [10, 20, 25]),
75+
("max", 25, [10, 25, 30]),
76+
],
77+
)
78+
def test_operations(x, copy, op, arg, expect):
79+
with assert_copy(x, copy):
80+
y = getattr(at(x, slice(1, None)), op)(arg, copy=copy)
81+
assert isinstance(y, type(x))
82+
assert_array_equal(y, expect)
83+
84+
85+
@pytest.mark.parametrize("copy", [True, False, None])
86+
def test_get(x, copy):
87+
with assert_copy(x, copy):
88+
y = at(x, slice(2)).get(copy=copy)
89+
assert isinstance(y, type(x))
90+
assert_array_equal(y, [10, 20])
91+
# Let assert_copy test that y is a view or copy
92+
with suppress((TypeError, ValueError)):
93+
y[0] = 40
94+
95+
96+
@pytest.mark.parametrize(
97+
"idx",
98+
[
99+
[0, 1],
100+
(0, 1),
101+
np.array([0, 1], dtype="i1"),
102+
np.array([0, 1], dtype="u1"),
103+
lambda xp: xp.asarray([0, 1], dtype="i1"),
104+
lambda xp: xp.asarray([0, 1], dtype="u1"),
105+
[True, True, False],
106+
(True, True, False),
107+
np.array([True, True, False]),
108+
lambda xp: xp.asarray([True, True, False]),
109+
],
110+
)
111+
@pytest.mark.parametrize("tuple_index", [True, False])
112+
def test_get_fancy_indices(x, idx, tuple_index):
113+
"""get() with a fancy index always returns a copy"""
114+
if callable(idx):
115+
xp = array_namespace(x)
116+
idx = idx(xp)
117+
118+
if is_jax_array(x) and isinstance(idx, (list, tuple)):
119+
pytest.skip("JAX fancy indices must always be arrays")
120+
if is_pydata_sparse_array(x) and is_pydata_sparse_array(idx):
121+
pytest.skip("sparse fancy indices can't be sparse themselves")
122+
if is_dask_array(x) and isinstance(idx, tuple):
123+
pytest.skip("dask does not support tuples; only lists or arrays")
124+
if isinstance(idx, tuple) and not tuple_index:
125+
pytest.skip("tuple indices must always be wrapped in a tuple")
126+
127+
if tuple_index:
128+
idx = (idx,)
129+
130+
with assert_copy(x, True):
131+
y = at(x, idx).get()
132+
assert isinstance(y, type(x))
133+
assert_array_equal(y, [10, 20])
134+
# Let assert_copy test that y is a view or copy
135+
with suppress((TypeError, ValueError)):
136+
y[0] = 40
137+
138+
with assert_copy(x, True):
139+
y = at(x, idx).get(copy=None)
140+
assert isinstance(y, type(x))
141+
assert_array_equal(y, [10, 20])
142+
# Let assert_copy test that y is a view or copy
143+
with suppress((TypeError, ValueError)):
144+
y[0] = 40
145+
146+
with pytest.raises(ValueError, match="fancy index"):
147+
at(x, idx).get(copy=False)
148+
149+
150+
def test_variant_index_syntax(x):
151+
y = at(x)[:2].set(40)
152+
assert isinstance(y, type(x))
153+
assert_array_equal(y, [40, 40, 30])
154+
155+
with pytest.raises(ValueError):
156+
at(x, 1)[2]
157+
with pytest.raises(ValueError):
158+
at(x)[1][2]

tests/test_common.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import is_array_api_obj, device, to_device
8+
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
99

1010
from ._helpers import import_, wrapped_libraries, all_libraries
1111

@@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
5555
assert is_func(lib) == (func == is_namespace_functions[library])
5656

5757

58+
@pytest.mark.parametrize("library", all_libraries)
59+
def test_is_writeable_array(library):
60+
lib = import_(library)
61+
x = lib.asarray([1, 2, 3])
62+
if is_writeable_array(x):
63+
x[1] = 4
64+
np.testing.assert_equal(np.asarray(x), [1, 4, 3])
65+
else:
66+
with pytest.raises((TypeError, ValueError)):
67+
x[1] = 4
68+
69+
70+
def test_is_writeable_array_numpy():
71+
x = np.asarray([1, 2, 3])
72+
assert is_writeable_array(x)
73+
x.flags.writeable = False
74+
assert not is_writeable_array(x)
75+
76+
5877
@pytest.mark.parametrize("library", all_libraries)
5978
def test_device(library):
6079
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)