1515@testing .parameterize (
1616 * testing .product (
1717 {
18+ # "batched_gesv_limit": [None, 0],
1819 "order" : ["C" , "F" ],
1920 }
2021 )
2122)
2223@testing .fix_random ()
2324class TestSolve (unittest .TestCase ):
24- # TODO: add get_batched_gesv_limit
25+
2526 # def setUp(self):
2627 # if self.batched_gesv_limit is not None:
2728 # self.old_limit = get_batched_gesv_limit()
@@ -32,6 +33,7 @@ class TestSolve(unittest.TestCase):
3233 # set_batched_gesv_limit(self.old_limit)
3334
3435 @testing .for_dtypes ("ifdFD" )
36+ # TODO(kataoka): Fix contiguity
3537 @testing .numpy_cupy_allclose (
3638 atol = 1e-3 , contiguous_check = False , type_check = has_support_aspect64 ()
3739 )
@@ -47,6 +49,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
4749 testing .assert_array_equal (b_copy , b )
4850 return result
4951
52+ @testing .with_requires ("numpy>=2.0" )
5053 def test_solve (self ):
5154 self .check_x ((4 , 4 ), (4 ,))
5255 self .check_x ((5 , 5 ), (5 , 2 ))
@@ -55,15 +58,9 @@ def test_solve(self):
5558 self .check_x ((0 , 0 ), (0 ,))
5659 self .check_x ((0 , 0 ), (0 , 2 ))
5760 self .check_x ((0 , 2 , 2 ), (0 , 2 , 3 ))
58- # In numpy 2.0 the broadcast ambiguity has been removed and now
59- # b is treaded as a single vector if and only if it is 1-dimensional;
60- # for other cases this signature must be followed
61- # (..., m, m), (..., m, n) -> (..., m, n)
62- # https://github.com/numpy/numpy/pull/25914
63- if numpy .lib .NumpyVersion (numpy .__version__ ) < "2.0.0" :
64- self .check_x ((2 , 4 , 4 ), (2 , 4 ))
65- self .check_x ((2 , 3 , 2 , 2 ), (2 , 3 , 2 ))
66- self .check_x ((0 , 2 , 2 ), (0 , 2 ))
61+ # Allowed since numpy 2
62+ self .check_x ((2 , 3 , 3 ), (3 ,))
63+ self .check_x ((2 , 5 , 3 , 3 ), (3 ,))
6764
6865 def check_shape (self , a_shape , b_shape , error_types ):
6966 for xp , error_type in error_types .items ():
@@ -76,12 +73,14 @@ def check_shape(self, a_shape, b_shape, error_types):
7673 # NumPy with OpenBLAS returns an empty array
7774 # while numpy with OneMKL raises LinAlgError
7875 @pytest .mark .skip ("Undefined behavior" )
76+ @testing .numpy_cupy_allclose ()
7977 def test_solve_singular_empty (self , xp ):
8078 a = xp .zeros ((3 , 3 )) # singular
8179 b = xp .empty ((3 , 0 )) # nrhs = 0
8280 # LinAlgError("Singular matrix") is not raised
8381 return xp .linalg .solve (a , b )
8482
83+ @testing .with_requires ("numpy>=2.0" )
8584 def test_invalid_shape (self ):
8685 linalg_errors = {
8786 numpy : numpy .linalg .LinAlgError ,
@@ -96,11 +95,35 @@ def test_invalid_shape(self):
9695 self .check_shape ((3 , 3 ), (2 ,), value_errors )
9796 self .check_shape ((3 , 3 ), (2 , 2 ), value_errors )
9897 self .check_shape ((3 , 3 , 4 ), (3 ,), linalg_errors )
99- # Since numpy >= 2.0, this case does not raise an error
100- if numpy .lib .NumpyVersion (numpy .__version__ ) < "2.0.0" :
101- self .check_shape ((2 , 3 , 3 ), (3 ,), value_errors )
10298 self .check_shape ((3 , 3 ), (0 ,), value_errors )
10399 self .check_shape ((0 , 3 , 4 ), (3 ,), linalg_errors )
100+ self .check_shape ((3 , 3 ), (), value_errors )
101+ # Not allowed since numpy 2
102+ self .check_shape (
103+ (0 , 2 , 2 ),
104+ (
105+ 0 ,
106+ 2 ,
107+ ),
108+ value_errors ,
109+ )
110+ self .check_shape (
111+ (2 , 4 , 4 ),
112+ (
113+ 2 ,
114+ 4 ,
115+ ),
116+ value_errors ,
117+ )
118+ self .check_shape (
119+ (2 , 3 , 2 , 2 ),
120+ (
121+ 2 ,
122+ 3 ,
123+ 2 ,
124+ ),
125+ value_errors ,
126+ )
104127
105128
106129@testing .parameterize (
@@ -113,6 +136,7 @@ def test_invalid_shape(self):
113136)
114137@testing .fix_random ()
115138class TestTensorSolve (unittest .TestCase ):
139+
116140 @testing .for_dtypes ("ifdFD" )
117141 @testing .numpy_cupy_allclose (atol = 0.02 , type_check = has_support_aspect64 ())
118142 def test_tensorsolve (self , xp , dtype ):
@@ -131,6 +155,7 @@ def test_tensorsolve(self, xp, dtype):
131155 )
132156)
133157class TestInv (unittest .TestCase ):
158+
134159 @testing .for_dtypes ("ifdFD" )
135160 @_condition .retry (10 )
136161 def check_x (self , a_shape , dtype ):
@@ -140,7 +165,6 @@ def check_x(self, a_shape, dtype):
140165 a_gpu_copy = a_gpu .copy ()
141166 result_cpu = numpy .linalg .inv (a_cpu )
142167 result_gpu = cupy .linalg .inv (a_gpu )
143-
144168 assert_dtype_allclose (result_gpu , result_cpu )
145169 testing .assert_array_equal (a_gpu_copy , a_gpu )
146170
@@ -170,6 +194,7 @@ def test_invalid_shape(self):
170194
171195
172196class TestInvInvalid (unittest .TestCase ):
197+
173198 @testing .for_dtypes ("ifdFD" )
174199 def test_inv (self , dtype ):
175200 for xp in (numpy , cupy ):
@@ -192,6 +217,7 @@ def test_batched_inv(self, dtype):
192217
193218
194219class TestPinv (unittest .TestCase ):
220+
195221 @testing .for_dtypes ("ifdFD" )
196222 @_condition .retry (10 )
197223 def check_x (self , a_shape , rcond , dtype ):
@@ -234,6 +260,7 @@ def test_pinv_size_0(self):
234260
235261
236262class TestLstsq :
263+
237264 @testing .for_dtypes ("ifdFD" )
238265 @testing .numpy_cupy_allclose (atol = 1e-3 , type_check = has_support_aspect64 ())
239266 def check_lstsq_solution (
@@ -312,20 +339,18 @@ def test_invalid_shapes(self):
312339 self .check_invalid_shapes ((3 , 3 ), (2 , 2 ))
313340 self .check_invalid_shapes ((4 , 3 ), (10 , 3 , 3 ))
314341
315- # dpnp.linalg.lstsq() does not raise a FutureWarning
316- # because dpnp did not have a previous implementation of dpnp.linalg.lstsq()
317- # and there is no need to get rid of old deprecated behavior as numpy did.
318- @pytest .mark .skip ("No support of deprecated behavior" )
342+ @testing .with_requires ("numpy>=2.0" )
319343 @testing .for_float_dtypes (no_float16 = True )
320344 @testing .numpy_cupy_allclose (atol = 1e-3 )
321- def test_warn_rcond (self , xp , dtype ):
345+ def test_nowarn_rcond (self , xp , dtype ):
322346 a = testing .shaped_random ((3 , 3 ), xp , dtype )
323347 b = testing .shaped_random ((3 ,), xp , dtype )
324- with testing . assert_warns ( FutureWarning ):
325- return xp .linalg .lstsq (a , b )
348+ # FutureWarning is no longer emitted
349+ return xp .linalg .lstsq (a , b )
326350
327351
328352class TestTensorInv (unittest .TestCase ):
353+
329354 @testing .for_dtypes ("ifdFD" )
330355 @_condition .retry (10 )
331356 def check_x (self , a_shape , ind , dtype ):
0 commit comments