Skip to content

Commit fafce08

Browse files
author
Vahid Tavanashad
committed
add no_int8 kwarg
1 parent ece0855 commit fafce08

File tree

5 files changed

+41
-92
lines changed

5 files changed

+41
-92
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_einsum.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,11 @@ def test_scalar_float(self, xp, dtype):
464464
)
465465
)
466466
class TestEinSumBinaryOperation:
467-
@testing.for_all_dtypes_combination(
468-
["dtype_a", "dtype_b"], no_bool=False, no_float16=False
469-
)
467+
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"], no_int8=True)
470468
@testing.numpy_cupy_allclose(
471469
type_check=has_support_aspect64(), contiguous_check=False
472470
)
473471
def test_einsum_binary(self, xp, dtype_a, dtype_b):
474-
if all(dtype in [xp.int8, xp.uint8] for dtype in [dtype_a, dtype_b]):
475-
pytest.skip("avoid overflow")
476472
a = testing.shaped_arange(self.shape_a, xp, dtype_a)
477473
b = testing.shaped_arange(self.shape_b, xp, dtype_b)
478474
# casting should be added for dpnp to allow cast int64 to float32
@@ -558,9 +554,7 @@ def test_scalar_2(self, xp, dtype):
558554
)
559555
class TestEinSumTernaryOperation:
560556

561-
@testing.for_all_dtypes_combination(
562-
["dtype_a", "dtype_b", "dtype_c"], no_bool=False, no_float16=False
563-
)
557+
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b", "dtype_c"])
564558
@testing.numpy_cupy_allclose(
565559
type_check=has_support_aspect64(), contiguous_check=False
566560
)

dpnp/tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,9 @@
4040
)
4141
class TestDot(unittest.TestCase):
4242

43-
# Avoid overflow
44-
skip_dtypes = {
45-
(numpy.int8, numpy.int8),
46-
(numpy.int8, numpy.uint8),
47-
(numpy.uint8, numpy.uint8),
48-
}
49-
50-
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
43+
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"], no_int8=True)
5144
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
5245
def test_dot(self, xp, dtype_a, dtype_b):
53-
if (dtype_a, dtype_b) in self.skip_dtypes or (
54-
dtype_b,
55-
dtype_a,
56-
) in self.skip_dtypes:
57-
pytest.skip("avoid overflow")
5846
shape_a, shape_b = self.shape
5947
if self.trans_a:
6048
a = testing.shaped_arange(shape_a[::-1], xp, dtype_a).T
@@ -250,20 +238,16 @@ def test_dot_vec3(self, xp, dtype):
250238
b = testing.shaped_arange((2,), xp, dtype)
251239
return xp.dot(a, b)
252240

253-
@testing.for_all_dtypes()
241+
@testing.for_all_dtypes(no_int8=True)
254242
@testing.numpy_cupy_allclose()
255243
def test_transposed_dot(self, xp, dtype):
256-
if dtype in [numpy.int8, numpy.uint8]:
257-
pytest.skip("avoid overflow")
258244
a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(1, 0, 2)
259245
b = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(0, 2, 1)
260246
return xp.dot(a, b)
261247

262-
@testing.for_all_dtypes()
248+
@testing.for_all_dtypes(no_int8=True)
263249
@testing.numpy_cupy_allclose()
264250
def test_transposed_dot_with_out(self, xp, dtype):
265-
if dtype in [numpy.int8, numpy.uint8]:
266-
pytest.skip("avoid overflow")
267251
a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(1, 0, 2)
268252
b = testing.shaped_arange((4, 2, 3), xp, dtype).transpose(2, 0, 1)
269253
c = xp.ndarray((3, 2, 3, 2), dtype=dtype)
@@ -336,20 +320,16 @@ def test_reversed_inner(self, xp, dtype):
336320
b = testing.shaped_reverse_arange((5,), xp, dtype)[::-1]
337321
return xp.inner(a, b)
338322

339-
@testing.for_all_dtypes()
323+
@testing.for_all_dtypes(no_int8=True)
340324
@testing.numpy_cupy_allclose()
341325
def test_multidim_inner(self, xp, dtype):
342-
if dtype in [numpy.int8, numpy.uint8]:
343-
pytest.skip("avoid overflow")
344326
a = testing.shaped_arange((2, 3, 4), xp, dtype)
345327
b = testing.shaped_arange((3, 2, 4), xp, dtype)
346328
return xp.inner(a, b)
347329

348-
@testing.for_all_dtypes()
330+
@testing.for_all_dtypes(no_int8=True)
349331
@testing.numpy_cupy_allclose()
350332
def test_transposed_higher_order_inner(self, xp, dtype):
351-
if dtype in [numpy.int8, numpy.uint8]:
352-
pytest.skip("avoid overflow")
353333
a = testing.shaped_arange((2, 4, 3), xp, dtype).transpose(2, 0, 1)
354334
b = testing.shaped_arange((4, 2, 3), xp, dtype).transpose(1, 2, 0)
355335
return xp.inner(a, b)
@@ -375,20 +355,16 @@ def test_multidim_outer(self, xp, dtype):
375355
b = testing.shaped_arange((4, 5), xp, dtype)
376356
return xp.outer(a, b)
377357

378-
@testing.for_all_dtypes()
358+
@testing.for_all_dtypes(no_int8=True)
379359
@testing.numpy_cupy_allclose()
380360
def test_tensordot(self, xp, dtype):
381-
if dtype in [numpy.int8, numpy.uint8]:
382-
pytest.skip("avoid overflow")
383361
a = testing.shaped_arange((2, 3, 4), xp, dtype)
384362
b = testing.shaped_arange((3, 4, 5), xp, dtype)
385363
return xp.tensordot(a, b)
386364

387-
@testing.for_all_dtypes()
365+
@testing.for_all_dtypes(no_int8=True)
388366
@testing.numpy_cupy_allclose()
389367
def test_transposed_tensordot(self, xp, dtype):
390-
if dtype in [numpy.int8, numpy.uint8]:
391-
pytest.skip("avoid overflow")
392368
a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(1, 0, 2)
393369
b = testing.shaped_arange((4, 3, 2), xp, dtype).transpose(2, 0, 1)
394370
return xp.tensordot(a, b)
@@ -540,19 +516,15 @@ def test_matrix_power_1(self, xp, dtype):
540516
a = testing.shaped_arange((3, 3), xp, dtype)
541517
return xp.linalg.matrix_power(a, 1)
542518

543-
@testing.for_all_dtypes()
519+
@testing.for_all_dtypes(no_int8=True)
544520
@testing.numpy_cupy_allclose()
545521
def test_matrix_power_2(self, xp, dtype):
546-
if dtype in [numpy.int8, numpy.uint8]:
547-
pytest.skip("avoid overflow")
548522
a = testing.shaped_arange((3, 3), xp, dtype)
549523
return xp.linalg.matrix_power(a, 2)
550524

551-
@testing.for_all_dtypes()
525+
@testing.for_all_dtypes(no_int8=True)
552526
@testing.numpy_cupy_allclose()
553527
def test_matrix_power_3(self, xp, dtype):
554-
if dtype in [numpy.int8, numpy.uint8]:
555-
pytest.skip("avoid overflow")
556528
a = testing.shaped_arange((3, 3), xp, dtype)
557529
return xp.linalg.matrix_power(a, 3)
558530

dpnp/tests/third_party/cupy/logic_tests/test_comparison.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,9 @@ def test_allclose_array_scalar(self, xp, dtype):
242242

243243
class TestIsclose(unittest.TestCase):
244244

245-
@testing.for_all_dtypes(no_complex=True)
245+
@testing.for_all_dtypes(no_complex=True, no_int8=True)
246246
@testing.numpy_cupy_array_equal()
247247
def test_is_close_finite(self, xp, dtype):
248-
if dtype in [xp.int8, xp.uint8]:
249-
pytest.skip("avoid overflow")
250248
# In numpy<1.10 this test fails when dtype is bool
251249
a = xp.array([0.9e-5, 1.1e-5, 1000 + 1e-4, 1000 - 1e-4]).astype(dtype)
252250
b = xp.array([0, 0, 1000, 1000]).astype(dtype)

dpnp/tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -60,39 +60,22 @@
6060
)
6161
class TestMatmul(unittest.TestCase):
6262

63-
# Avoid overflow
64-
skip_dtypes = {
65-
(numpy.int8, numpy.int8),
66-
(numpy.int8, numpy.uint8),
67-
(numpy.uint8, numpy.uint8),
68-
}
69-
70-
@testing.for_all_dtypes(name="dtype1")
71-
@testing.for_all_dtypes(name="dtype2")
63+
@testing.for_all_dtypes(name="dtype1", no_int8=True)
64+
@testing.for_all_dtypes(name="dtype2", no_int8=True)
7265
@testing.numpy_cupy_allclose(
7366
rtol=1e-3, atol=1e-3, type_check=has_support_aspect64()
7467
) # required for uint8
7568
def test_operator_matmul(self, xp, dtype1, dtype2):
76-
if (dtype1, dtype2) in self.skip_dtypes or (
77-
dtype2,
78-
dtype1,
79-
) in self.skip_dtypes:
80-
pytest.skip("avoid overflow")
8169
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
8270
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
8371
return operator.matmul(x1, x2)
8472

85-
@testing.for_all_dtypes(name="dtype1")
86-
@testing.for_all_dtypes(name="dtype2")
73+
@testing.for_all_dtypes(name="dtype1", no_int8=True)
74+
@testing.for_all_dtypes(name="dtype2", no_int8=True)
8775
@testing.numpy_cupy_allclose(
8876
rtol=1e-3, atol=1e-3, type_check=has_support_aspect64()
8977
) # required for uint8
9078
def test_cupy_matmul(self, xp, dtype1, dtype2):
91-
if (dtype1, dtype2) in self.skip_dtypes or (
92-
dtype2,
93-
dtype1,
94-
) in self.skip_dtypes:
95-
pytest.skip("avoid overflow")
9679
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
9780
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
9881
return xp.matmul(x1, x2)
@@ -114,25 +97,12 @@ def test_cupy_matmul(self, xp, dtype1, dtype2):
11497
)
11598
class TestMatmulOut(unittest.TestCase):
11699

117-
# Avoid overflow
118-
skip_dtypes = {
119-
(numpy.int8, numpy.int8),
120-
(numpy.int8, numpy.uint8),
121-
(numpy.uint8, numpy.uint8),
122-
}
123-
124-
@testing.for_all_dtypes(name="dtype1")
125-
@testing.for_all_dtypes(name="dtype2")
100+
@testing.for_all_dtypes(name="dtype1", no_int8=True)
101+
@testing.for_all_dtypes(name="dtype2", no_int8=True)
126102
@testing.numpy_cupy_allclose(
127103
rtol=1e-3, atol=1e-3, accept_error=TypeError # required for uint8
128104
)
129105
def test_cupy_matmul_noncontiguous(self, xp, dtype1, dtype2):
130-
if (dtype1, dtype2) in self.skip_dtypes or (
131-
dtype2,
132-
dtype1,
133-
) in self.skip_dtypes:
134-
pytest.skip("avoid overflow")
135-
136106
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
137107
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
138108
out = xp.zeros(self.shape_pair[2], dtype=dtype1)[::-1]
@@ -170,11 +140,9 @@ def test_overlap_both(self, xp, dtype, shape):
170140

171141
class TestMatmulStrides:
172142

173-
@testing.for_all_dtypes()
143+
@testing.for_all_dtypes(no_int8=True)
174144
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8
175145
def test_relaxed_c_contiguous_input(self, xp, dtype):
176-
if dtype in [numpy.int8, numpy.uint8]:
177-
pytest.skip("avoid overflow")
178146
x1 = testing.shaped_arange((2, 2, 3), xp, dtype)[:, None, :, :]
179147
x2 = testing.shaped_arange((2, 1, 3, 1), xp, dtype)
180148
return x1 @ x2

dpnp/tests/third_party/cupy/testing/_loops.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ def _get_int_bool_dtypes():
10911091
_dtypes = _float_dtypes + _int_bool_dtypes
10921092

10931093

1094-
def _make_all_dtypes(no_float16, no_bool, no_complex):
1094+
def _make_all_dtypes(no_float16, no_bool, no_complex, no_int8):
10951095
if no_float16:
10961096
dtypes = _regular_float_dtypes
10971097
else:
@@ -1102,14 +1102,24 @@ def _make_all_dtypes(no_float16, no_bool, no_complex):
11021102
else:
11031103
dtypes += _int_bool_dtypes
11041104

1105+
if no_int8:
1106+
_dtypes = list(dtypes)
1107+
_dtypes.remove(numpy.int8)
1108+
_dtypes.remove(numpy.uint8)
1109+
dtypes = tuple(_dtypes)
1110+
11051111
if config.complex_types and not no_complex:
11061112
dtypes += _complex_dtypes
11071113

11081114
return dtypes
11091115

11101116

11111117
def for_all_dtypes(
1112-
name="dtype", no_float16=False, no_bool=False, no_complex=False
1118+
name="dtype",
1119+
no_float16=False,
1120+
no_bool=False,
1121+
no_complex=False,
1122+
no_int8=False,
11131123
):
11141124
"""Decorator that checks the fixture with all dtypes.
11151125
@@ -1121,6 +1131,9 @@ def for_all_dtypes(
11211131
omitted from candidate dtypes.
11221132
no_complex(bool): If ``True``, ``numpy.complex64`` and
11231133
``numpy.complex128`` are omitted from candidate dtypes.
1134+
no_int8(bool): If ``True``, ``numpy.int8`` and
1135+
``numpy.uint8`` are omitted from candidate dtypes.
1136+
This option is generally used to avoid overflow.
11241137
11251138
dtypes to be tested: ``numpy.complex64`` (optional),
11261139
``numpy.complex128`` (optional),
@@ -1164,7 +1177,7 @@ def for_all_dtypes(
11641177
.. seealso:: :func:`cupy.testing.for_dtypes`
11651178
"""
11661179
return for_dtypes(
1167-
_make_all_dtypes(no_float16, no_bool, no_complex), name=name
1180+
_make_all_dtypes(no_float16, no_bool, no_complex, no_int8), name=name
11681181
)
11691182

11701183

@@ -1334,6 +1347,7 @@ def for_all_dtypes_combination(
13341347
no_bool=False,
13351348
full=None,
13361349
no_complex=False,
1350+
no_int8=False,
13371351
):
13381352
"""Decorator that checks the fixture with a product set of all dtypes.
13391353
@@ -1347,12 +1361,15 @@ def for_all_dtypes_combination(
13471361
will be tested.
13481362
Otherwise, the subset of combinations will be tested
13491363
(see description in :func:`cupy.testing.for_dtypes_combination`).
1350-
no_complex(bool): If, True, ``numpy.complex64`` and
1364+
no_complex(bool): If, ``True``, ``numpy.complex64`` and
13511365
``numpy.complex128`` are omitted from candidate dtypes.
1366+
no_int8(bool): If, ``True``, ``numpy.int8`` and
1367+
``numpy.uint8`` are omitted from candidate dtypes.
1368+
This option is generally used to avoid overflow.
13521369
13531370
.. seealso:: :func:`cupy.testing.for_dtypes_combination`
13541371
"""
1355-
types = _make_all_dtypes(no_float16, no_bool, no_complex)
1372+
types = _make_all_dtypes(no_float16, no_bool, no_complex, no_int8)
13561373
return for_dtypes_combination(types, names, full)
13571374

13581375

0 commit comments

Comments
 (0)