4040)
4141class TestDot (unittest .TestCase ):
4242
43- # Avoid overflow
44- skip_dtypes = {
45- (numpy .int8 , numpy .int8 ),
46- (numpy .int8 , numpy .uint8 ),
47- (numpy .uint8 , numpy .uint8 ),
48- }
49-
50- @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
43+ @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ], no_int8 = True )
5144 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
5245 def test_dot (self , xp , dtype_a , dtype_b ):
53- if (dtype_a , dtype_b ) in self .skip_dtypes or (
54- dtype_b ,
55- dtype_a ,
56- ) in self .skip_dtypes :
57- pytest .skip ("avoid overflow" )
5846 shape_a , shape_b = self .shape
5947 if self .trans_a :
6048 a = testing .shaped_arange (shape_a [::- 1 ], xp , dtype_a ).T
@@ -250,20 +238,16 @@ def test_dot_vec3(self, xp, dtype):
250238 b = testing .shaped_arange ((2 ,), xp , dtype )
251239 return xp .dot (a , b )
252240
253- @testing .for_all_dtypes ()
241+ @testing .for_all_dtypes (no_int8 = True )
254242 @testing .numpy_cupy_allclose ()
255243 def test_transposed_dot (self , xp , dtype ):
256- if dtype in [numpy .int8 , numpy .uint8 ]:
257- pytest .skip ("avoid overflow" )
258244 a = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype ).transpose (1 , 0 , 2 )
259245 b = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype ).transpose (0 , 2 , 1 )
260246 return xp .dot (a , b )
261247
262- @testing .for_all_dtypes ()
248+ @testing .for_all_dtypes (no_int8 = True )
263249 @testing .numpy_cupy_allclose ()
264250 def test_transposed_dot_with_out (self , xp , dtype ):
265- if dtype in [numpy .int8 , numpy .uint8 ]:
266- pytest .skip ("avoid overflow" )
267251 a = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype ).transpose (1 , 0 , 2 )
268252 b = testing .shaped_arange ((4 , 2 , 3 ), xp , dtype ).transpose (2 , 0 , 1 )
269253 c = xp .ndarray ((3 , 2 , 3 , 2 ), dtype = dtype )
@@ -336,20 +320,16 @@ def test_reversed_inner(self, xp, dtype):
336320 b = testing .shaped_reverse_arange ((5 ,), xp , dtype )[::- 1 ]
337321 return xp .inner (a , b )
338322
339- @testing .for_all_dtypes ()
323+ @testing .for_all_dtypes (no_int8 = True )
340324 @testing .numpy_cupy_allclose ()
341325 def test_multidim_inner (self , xp , dtype ):
342- if dtype in [numpy .int8 , numpy .uint8 ]:
343- pytest .skip ("avoid overflow" )
344326 a = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype )
345327 b = testing .shaped_arange ((3 , 2 , 4 ), xp , dtype )
346328 return xp .inner (a , b )
347329
348- @testing .for_all_dtypes ()
330+ @testing .for_all_dtypes (no_int8 = True )
349331 @testing .numpy_cupy_allclose ()
350332 def test_transposed_higher_order_inner (self , xp , dtype ):
351- if dtype in [numpy .int8 , numpy .uint8 ]:
352- pytest .skip ("avoid overflow" )
353333 a = testing .shaped_arange ((2 , 4 , 3 ), xp , dtype ).transpose (2 , 0 , 1 )
354334 b = testing .shaped_arange ((4 , 2 , 3 ), xp , dtype ).transpose (1 , 2 , 0 )
355335 return xp .inner (a , b )
@@ -375,20 +355,16 @@ def test_multidim_outer(self, xp, dtype):
375355 b = testing .shaped_arange ((4 , 5 ), xp , dtype )
376356 return xp .outer (a , b )
377357
378- @testing .for_all_dtypes ()
358+ @testing .for_all_dtypes (no_int8 = True )
379359 @testing .numpy_cupy_allclose ()
380360 def test_tensordot (self , xp , dtype ):
381- if dtype in [numpy .int8 , numpy .uint8 ]:
382- pytest .skip ("avoid overflow" )
383361 a = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype )
384362 b = testing .shaped_arange ((3 , 4 , 5 ), xp , dtype )
385363 return xp .tensordot (a , b )
386364
387- @testing .for_all_dtypes ()
365+ @testing .for_all_dtypes (no_int8 = True )
388366 @testing .numpy_cupy_allclose ()
389367 def test_transposed_tensordot (self , xp , dtype ):
390- if dtype in [numpy .int8 , numpy .uint8 ]:
391- pytest .skip ("avoid overflow" )
392368 a = testing .shaped_arange ((2 , 3 , 4 ), xp , dtype ).transpose (1 , 0 , 2 )
393369 b = testing .shaped_arange ((4 , 3 , 2 ), xp , dtype ).transpose (2 , 0 , 1 )
394370 return xp .tensordot (a , b )
@@ -540,19 +516,15 @@ def test_matrix_power_1(self, xp, dtype):
540516 a = testing .shaped_arange ((3 , 3 ), xp , dtype )
541517 return xp .linalg .matrix_power (a , 1 )
542518
543- @testing .for_all_dtypes ()
519+ @testing .for_all_dtypes (no_int8 = True )
544520 @testing .numpy_cupy_allclose ()
545521 def test_matrix_power_2 (self , xp , dtype ):
546- if dtype in [numpy .int8 , numpy .uint8 ]:
547- pytest .skip ("avoid overflow" )
548522 a = testing .shaped_arange ((3 , 3 ), xp , dtype )
549523 return xp .linalg .matrix_power (a , 2 )
550524
551- @testing .for_all_dtypes ()
525+ @testing .for_all_dtypes (no_int8 = True )
552526 @testing .numpy_cupy_allclose ()
553527 def test_matrix_power_3 (self , xp , dtype ):
554- if dtype in [numpy .int8 , numpy .uint8 ]:
555- pytest .skip ("avoid overflow" )
556528 a = testing .shaped_arange ((3 , 3 ), xp , dtype )
557529 return xp .linalg .matrix_power (a , 3 )
558530
0 commit comments