Skip to content

Commit 2a187a2

Browse files
committed
Merge branch 'main' into torch_meta
2 parents fd7a799 + dc44205 commit 2a187a2

File tree

1 file changed

+132
-139
lines changed

1 file changed

+132
-139
lines changed

tests/test_testing.py

Lines changed: 132 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Callable
2-
from contextlib import nullcontext
32
from types import ModuleType
43
from typing import cast
54

@@ -21,23 +20,9 @@
2120
from array_api_extra._lib._utils._typing import Array, Device
2221
from array_api_extra.testing import lazy_xp_function
2322

24-
# mypy: disable-error-code=decorated-any
23+
# mypy: disable-error-code="decorated-any, explicit-any"
2524
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
2625

27-
param_assert_equal_close = pytest.mark.parametrize(
28-
"func",
29-
[
30-
xp_assert_equal,
31-
xp_assert_less,
32-
pytest.param(
33-
xp_assert_close,
34-
marks=pytest.mark.xfail_xp_backend(
35-
Backend.SPARSE, reason="no isdtype", strict=False
36-
),
37-
),
38-
],
39-
)
40-
4126

4227
class TestAsNumPyArray:
4328
def test_basic(self, xp: ModuleType):
@@ -57,136 +42,144 @@ def test_device(self, xp: ModuleType, library: Backend, device: Device):
5742
xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
5843

5944

60-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
61-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
62-
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
63-
func(xp.asarray(0), xp.asarray(0))
64-
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
65-
66-
with pytest.raises(AssertionError, match="shapes do not match"):
67-
func(xp.asarray([0]), xp.asarray([[0]]))
68-
69-
with pytest.raises(AssertionError, match="dtypes do not match"):
70-
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
71-
72-
with pytest.raises(AssertionError):
73-
func(xp.asarray([1, 2]), xp.asarray([1, 3]))
74-
75-
with pytest.raises(AssertionError, match="hello"):
76-
func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello")
77-
78-
79-
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
80-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
81-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
82-
def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
83-
with pytest.raises(AssertionError, match="namespaces do not match"):
84-
func(xp.asarray(0), np.asarray(0))
85-
with pytest.raises(TypeError, match="Unrecognized array input"):
86-
func(xp.asarray(0), 0)
87-
with pytest.raises(TypeError, match="list is not a supported array type"):
88-
func(xp.asarray([0]), [0])
89-
90-
91-
@param_assert_equal_close
92-
@pytest.mark.parametrize("check_shape", [False, True])
93-
def test_assert_close_equal_less_shape( # type: ignore[explicit-any]
94-
xp: ModuleType,
95-
func: Callable[..., None],
96-
check_shape: bool,
97-
):
98-
context = (
99-
pytest.raises(AssertionError, match="shapes do not match")
100-
if check_shape
101-
else nullcontext()
102-
)
103-
with context:
104-
# note: NaNs are handled by all 3 checks
105-
func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
106-
107-
108-
@param_assert_equal_close
109-
@pytest.mark.parametrize("check_dtype", [False, True])
110-
def test_assert_close_equal_less_dtype( # type: ignore[explicit-any]
111-
xp: ModuleType,
112-
func: Callable[..., None],
113-
check_dtype: bool,
114-
):
115-
context = (
116-
pytest.raises(AssertionError, match="dtypes do not match")
117-
if check_dtype
118-
else nullcontext()
45+
class TestAssertEqualCloseLess:
46+
pr_assert_close = pytest.param( # pyright: ignore[reportUnannotatedClassAttribute]
47+
xp_assert_close,
48+
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
11949
)
120-
with context:
121-
func(
122-
xp.asarray(xp.nan, dtype=xp.float32),
123-
xp.asarray(xp.nan, dtype=xp.float64),
124-
check_dtype=check_dtype,
125-
)
126-
127-
128-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
129-
@pytest.mark.parametrize("check_scalar", [False, True])
130-
def test_assert_close_equal_less_scalar( # type: ignore[explicit-any]
131-
xp: ModuleType,
132-
func: Callable[..., None],
133-
check_scalar: bool,
134-
):
135-
context = (
136-
pytest.raises(AssertionError, match="array-ness does not match")
137-
if check_scalar
138-
else nullcontext()
139-
)
140-
with context:
141-
func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)
142-
14350

144-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
145-
def test_assert_close_tolerance(xp: ModuleType):
146-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)
147-
with pytest.raises(AssertionError):
148-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01)
51+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close])
52+
def test_assert_equal_close_basic(self, xp: ModuleType, func: Callable[..., None]):
53+
func(xp.asarray(0), xp.asarray(0))
54+
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
14955

150-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
151-
with pytest.raises(AssertionError):
152-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
56+
with pytest.raises(AssertionError, match="Mismatched elements"):
57+
func(xp.asarray([1, 2]), xp.asarray([2, 1]))
15358

59+
with pytest.raises(AssertionError, match="hello"):
60+
func(xp.asarray([1, 2]), xp.asarray([2, 1]), err_msg="hello")
15461

155-
def test_assert_less_basic(xp: ModuleType):
156-
xp_assert_less(xp.asarray(-1), xp.asarray(0))
157-
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
158-
with pytest.raises(AssertionError):
159-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
160-
with pytest.raises(AssertionError, match="hello"):
161-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
62+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
63+
def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
64+
with pytest.raises(AssertionError, match="shapes do not match"):
65+
func(xp.asarray([0]), xp.asarray([[0]]))
16266

67+
with pytest.raises(AssertionError, match="dtypes do not match"):
68+
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
16369

164-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
165-
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
166-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
167-
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
168-
"""On Dask and other lazy backends, test that a shape with NaN's or None's
169-
can be compared to a real shape.
170-
"""
171-
a = xp.asarray([1, 2])
172-
a = a[a > 1]
173-
174-
func(a, xp.asarray([2]))
175-
with pytest.raises(AssertionError):
176-
func(a, xp.asarray([2, 3]))
177-
with pytest.raises(AssertionError):
178-
func(a, xp.asarray(2))
179-
with pytest.raises(AssertionError):
180-
func(a, xp.asarray([3]))
181-
182-
# Swap actual and desired
183-
func(xp.asarray([2]), a)
184-
with pytest.raises(AssertionError):
185-
func(xp.asarray([2, 3]), a)
186-
with pytest.raises(AssertionError):
187-
func(xp.asarray(2), a)
188-
with pytest.raises(AssertionError):
189-
func(xp.asarray([3]), a)
70+
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
71+
@pytest.mark.skip_xp_backend(
72+
Backend.NUMPY_READONLY, reason="test other ns vs. numpy"
73+
)
74+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
75+
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
76+
with pytest.raises(AssertionError, match="namespaces do not match"):
77+
func(xp.asarray(0), np.asarray(0))
78+
with pytest.raises(TypeError, match="Unrecognized array input"):
79+
func(xp.asarray(0), 0)
80+
with pytest.raises(TypeError, match="list is not a supported array type"):
81+
func(xp.asarray([0]), [0])
82+
83+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
84+
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
85+
a = xp.asarray([1] if func is xp_assert_less else [2])
86+
b = xp.asarray(2)
87+
c = xp.asarray(0)
88+
d = xp.asarray([2, 2])
89+
90+
with pytest.raises(AssertionError, match="shapes do not match"):
91+
func(a, b)
92+
func(a, b, check_shape=False)
93+
with pytest.raises(AssertionError, match="Mismatched elements"):
94+
func(a, c, check_shape=False)
95+
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
96+
func(a, d, check_shape=False)
97+
98+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
99+
def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
100+
a = xp.asarray(1 if func is xp_assert_less else 2)
101+
b = xp.asarray(2, dtype=xp.int16)
102+
c = xp.asarray(0, dtype=xp.int16)
103+
104+
with pytest.raises(AssertionError, match="dtypes do not match"):
105+
func(a, b)
106+
func(a, b, check_dtype=False)
107+
with pytest.raises(AssertionError, match="Mismatched elements"):
108+
func(a, c, check_dtype=False)
109+
110+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
111+
@pytest.mark.xfail_xp_backend(
112+
Backend.SPARSE, reason="sparse [()] returns np.generic"
113+
)
114+
def test_check_scalar(
115+
self, xp: ModuleType, library: Backend, func: Callable[..., None]
116+
):
117+
a = xp.asarray(1 if func is xp_assert_less else 2)
118+
b = xp.asarray(2)[()] # Note: only makes a difference on NumPy
119+
c = xp.asarray(0)
120+
121+
func(a, b)
122+
if library.like(Backend.NUMPY):
123+
with pytest.raises(AssertionError, match="array-ness does not match"):
124+
func(a, b, check_scalar=True)
125+
else:
126+
func(a, b, check_scalar=True)
127+
with pytest.raises(AssertionError, match="Mismatched elements"):
128+
func(a, c, check_scalar=True)
129+
130+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
131+
@pytest.mark.parametrize("dtype", ["int64", "float64"])
132+
def test_assert_close_tolerance(self, dtype: str, xp: ModuleType):
133+
a = xp.asarray([100], dtype=getattr(xp, dtype))
134+
b = xp.asarray([102], dtype=getattr(xp, dtype))
135+
136+
with pytest.raises(AssertionError, match="Mismatched elements"):
137+
xp_assert_close(a, b)
138+
139+
xp_assert_close(a, b, rtol=0.03)
140+
with pytest.raises(AssertionError, match="Mismatched elements"):
141+
xp_assert_close(a, b, rtol=0.01)
142+
143+
xp_assert_close(a, b, atol=3)
144+
with pytest.raises(AssertionError, match="Mismatched elements"):
145+
xp_assert_close(a, b, atol=1)
146+
147+
def test_assert_less(self, xp: ModuleType):
148+
xp_assert_less(xp.asarray(-1), xp.asarray(0))
149+
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
150+
with pytest.raises(AssertionError, match="Mismatched elements"):
151+
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
152+
153+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
154+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
155+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
156+
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
157+
"""On Dask and other lazy backends, test that a shape with NaN's or None's
158+
can be compared to a real shape.
159+
"""
160+
# actual has shape=(None, )
161+
a = xp.asarray([1] if func is xp_assert_less else [2])
162+
a = a[a > 0]
163+
164+
func(a, xp.asarray([2]))
165+
with pytest.raises(AssertionError, match="shapes do not match"):
166+
func(a, xp.asarray(2))
167+
with pytest.raises(AssertionError, match="shapes do not match"):
168+
func(a, xp.asarray([2, 3]))
169+
with pytest.raises(AssertionError, match="Mismatched elements"):
170+
func(a, xp.asarray([0]))
171+
172+
# desired has shape=(None, )
173+
a = xp.asarray([3] if func is xp_assert_less else [2])
174+
a = a[a > 0]
175+
176+
func(xp.asarray([2]), a)
177+
with pytest.raises(AssertionError, match="shapes do not match"):
178+
func(xp.asarray(2), a)
179+
with pytest.raises(AssertionError, match="shapes do not match"):
180+
func(xp.asarray([2, 3]), a)
181+
with pytest.raises(AssertionError, match="Mismatched elements"):
182+
func(xp.asarray([4]), a)
190183

191184

192185
def good_lazy(x: Array) -> Array:

0 commit comments

Comments
 (0)