1+ import sys
12import unittest
3+ import warnings
24
35import numpy
46import pytest
3739 )
3840)
3941class TestDot (unittest .TestCase ):
42+
4043 @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
4144 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
4245 def test_dot (self , xp , dtype_a , dtype_b ):
@@ -87,32 +90,97 @@ def test_dot_with_out(self, xp, dtype_a, dtype_b, dtype_c):
8790 # Test for 0 dimension
8891 ((3 ,), (3 ,), - 1 , - 1 , - 1 ),
8992 # Test for basic cases
90- ((1 , 2 ), (1 , 2 ), - 1 , - 1 , 1 ),
9193 ((1 , 3 ), (1 , 3 ), 1 , - 1 , - 1 ),
94+ # Test for higher dimensions
95+ ((2 , 4 , 5 , 3 ), (2 , 4 , 5 , 3 ), - 1 , - 1 , 0 ),
96+ ],
97+ }
98+ )
99+ )
100+ class TestCrossProduct (unittest .TestCase ):
101+
102+ @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
103+ @testing .numpy_cupy_allclose ()
104+ def test_cross (self , xp , dtype_a , dtype_b ):
105+ if dtype_a == dtype_b == numpy .bool_ :
106+ # cross does not support bool-bool inputs.
107+ return xp .array (True )
108+ shape_a , shape_b , axisa , axisb , axisc = self .params
109+ a = testing .shaped_arange (shape_a , xp , dtype_a )
110+ b = testing .shaped_arange (shape_b , xp , dtype_b )
111+ return xp .cross (a , b , axisa , axisb , axisc )
112+
113+
114+ # XXX: cross with 2D vectors is deprecated in NumPy 2.0, also CuPy 1.14
115+ @testing .parameterize (
116+ * testing .product (
117+ {
118+ "params" : [
119+ # Test for basic cases
120+ ((1 , 2 ), (1 , 2 ), - 1 , - 1 , 1 ),
92121 ((1 , 2 ), (1 , 3 ), - 1 , - 1 , 1 ),
93122 ((2 , 2 ), (1 , 3 ), - 1 , - 1 , 0 ),
94123 ((3 , 3 ), (1 , 2 ), 0 , - 1 , - 1 ),
95124 ((0 , 3 ), (0 , 3 ), - 1 , - 1 , - 1 ),
96125 # Test for higher dimensions
97126 ((2 , 0 , 3 ), (2 , 0 , 3 ), 0 , 0 , 0 ),
98- ((2 , 4 , 5 , 3 ), (2 , 4 , 5 , 3 ), - 1 , - 1 , 0 ),
99127 ((2 , 4 , 5 , 2 ), (2 , 4 , 5 , 2 ), 0 , 0 , - 1 ),
100128 ],
101129 }
102130 )
103131)
104- class TestCrossProduct (unittest .TestCase ):
105- @pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
132+ class TestCrossProductDeprecated (unittest .TestCase ):
106133 @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
107- @testing .numpy_cupy_allclose (type_check = has_support_aspect64 () )
134+ @testing .numpy_cupy_allclose ()
108135 def test_cross (self , xp , dtype_a , dtype_b ):
109136 if dtype_a == dtype_b == numpy .bool_ :
110137 # cross does not support bool-bool inputs.
111138 return xp .array (True )
112139 shape_a , shape_b , axisa , axisb , axisc = self .params
113140 a = testing .shaped_arange (shape_a , xp , dtype_a )
114141 b = testing .shaped_arange (shape_b , xp , dtype_b )
115- return xp .cross (a , b , axisa , axisb , axisc )
142+
143+ with warnings .catch_warnings ():
144+ warnings .simplefilter ("ignore" , DeprecationWarning )
145+ res = xp .cross (a , b , axisa , axisb , axisc )
146+ return res
147+
148+
149+ @testing .parameterize (
150+ * testing .product (
151+ {
152+ "params" : [
153+ # Test for 0 dimension
154+ (
155+ (3 ,),
156+ (3 ,),
157+ - 1 ,
158+ ),
159+ # Test for basic cases
160+ (
161+ (1 , 3 ),
162+ (1 , 3 ),
163+ 1 ,
164+ ),
165+ # Test for higher dimensions
166+ ((2 , 4 , 5 , 3 ), (2 , 4 , 5 , 3 ), - 1 ),
167+ ],
168+ }
169+ )
170+ )
171+ class TestLinalgCrossProduct (unittest .TestCase ):
172+
173+ @testing .with_requires ("numpy>=2.0" )
174+ @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
175+ @testing .numpy_cupy_allclose ()
176+ def test_cross (self , xp , dtype_a , dtype_b ):
177+ if dtype_a == dtype_b == numpy .bool_ :
178+ # cross does not support bool-bool inputs.
179+ return xp .array (True )
180+ shape_a , shape_b , axis = self .params
181+ a = testing .shaped_arange (shape_a , xp , dtype_a )
182+ b = testing .shaped_arange (shape_b , xp , dtype_b )
183+ return xp .linalg .cross (a , b , axis = axis )
116184
117185
118186@testing .parameterize (
@@ -129,6 +197,7 @@ def test_cross(self, xp, dtype_a, dtype_b):
129197 )
130198)
131199class TestDotFor0Dim (unittest .TestCase ):
200+
132201 @testing .for_all_dtypes_combination (["dtype_a" , "dtype_b" ])
133202 @testing .numpy_cupy_allclose (
134203 type_check = has_support_aspect64 (), contiguous_check = False
@@ -147,6 +216,7 @@ def test_dot(self, xp, dtype_a, dtype_b):
147216
148217
149218class TestProduct :
219+
150220 @testing .for_all_dtypes ()
151221 @testing .numpy_cupy_allclose ()
152222 def test_dot_vec1 (self , xp , dtype ):
@@ -403,7 +473,9 @@ def test_zerodim_kron(self, xp, dtype):
403473 )
404474 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
405475 def test_kron_accepts_numbers_as_arguments (self , a , b , xp ):
406- args = [xp .array (arg ) if type (arg ) == list else arg for arg in [a , b ]]
476+ args = [
477+ xp .array (arg ) if isinstance (arg , list ) else arg for arg in [a , b ]
478+ ]
407479 return xp .kron (* args )
408480
409481
@@ -422,6 +494,7 @@ def test_kron_accepts_numbers_as_arguments(self, a, b, xp):
422494 )
423495)
424496class TestProductZeroLength (unittest .TestCase ):
497+
425498 @testing .for_all_dtypes ()
426499 @testing .numpy_cupy_allclose ()
427500 def test_tensordot_zero_length (self , xp , dtype ):
@@ -488,9 +561,13 @@ def test_matrix_power_large(self, xp, dtype):
488561 a = xp .eye (23 , k = 17 , dtype = dtype ) + xp .eye (23 , k = - 6 , dtype = dtype )
489562 return xp .linalg .matrix_power (a , 123456789123456789 )
490563
564+ @pytest .mark .skipif (
565+ sys .platform == "win32" , reason = "python int overflows C long"
566+ )
491567 @testing .for_float_dtypes (no_float16 = True )
492568 @testing .numpy_cupy_allclose ()
493569 def test_matrix_power_invlarge (self , xp , dtype ):
570+ # TODO (ev-br): np 2.0: check if it's fixed in numpy 2 (broken on 1.26)
494571 a = xp .eye (23 , k = 17 , dtype = dtype ) + xp .eye (23 , k = - 6 , dtype = dtype )
495572 return xp .linalg .matrix_power (a , - 987654321987654321 )
496573
@@ -504,6 +581,7 @@ def test_matrix_power_invlarge(self, xp, dtype):
504581)
505582@pytest .mark .parametrize ("n" , [0 , 5 , - 7 ])
506583class TestMatrixPowerBatched :
584+
507585 @testing .for_float_dtypes (no_float16 = True )
508586 @testing .numpy_cupy_allclose (rtol = 5e-5 )
509587 def test_matrix_power_batched (self , xp , dtype , shape , n ):
0 commit comments