@@ -21,6 +21,7 @@ def _get_hermitian(xp, a, UPLO):
2121 )
2222)
2323class TestEigenvalue :
24+
2425 @testing .for_all_dtypes ()
2526 @testing .numpy_cupy_allclose (
2627 rtol = 1e-3 ,
@@ -47,9 +48,7 @@ def test_eigh(self, xp, dtype):
4748 tol = 1e-3
4849 else :
4950 tol = 1e-5
50-
5151 testing .assert_allclose (A @ v , v @ xp .diag (w ), atol = tol , rtol = tol )
52-
5352 # Check if v @ vt is an identity matrix
5453 testing .assert_allclose (
5554 v @ v .swapaxes (- 2 , - 1 ).conj (),
@@ -87,7 +86,7 @@ def test_eigh_batched(self, xp, dtype):
8786 )
8887 return w
8988
90- @testing .for_complex_dtypes ( )
89+ @testing .for_dtypes ( "FD" )
9190 @testing .numpy_cupy_allclose (rtol = 1e-3 , atol = 1e-4 )
9291 def test_eigh_complex_batched (self , xp , dtype ):
9392 a = xp .array (
@@ -105,7 +104,6 @@ def test_eigh_complex_batched(self, xp, dtype):
105104 # eigenvectors, so v's are not directly comparable and we verify
106105 # them through the eigen equation A*v=w*v.
107106 A = _get_hermitian (xp , a , self .UPLO )
108-
109107 for i in range (a .shape [0 ]):
110108 testing .assert_allclose (
111109 A [i ].dot (v [i ]), w [i ] * v [i ], rtol = 1e-5 , atol = 1e-5
@@ -165,44 +163,54 @@ def test_eigvalsh_complex_batched(self, xp, dtype):
165163 return w
166164
167165
168- @testing .parameterize (
169- * testing .product (
170- {"UPLO" : ["U" , "L" ], "shape" : [(0 , 0 ), (2 , 0 , 0 ), (0 , 3 , 3 )]}
171- )
166+ @pytest .mark .parametrize ("UPLO" , ["U" , "L" ])
167+ @pytest .mark .parametrize (
168+ "shape" ,
169+ [
170+ (0 , 0 ),
171+ (2 , 0 , 0 ),
172+ (0 , 3 , 3 ),
173+ ],
172174)
173175class TestEigenvalueEmpty :
176+
174177 @testing .for_dtypes ("ifdFD" )
175178 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
176- def test_eigh (self , xp , dtype ):
177- a = xp .empty (self . shape , dtype = dtype )
179+ def test_eigh (self , xp , dtype , shape , UPLO ):
180+ a = xp .empty (shape , dtype = dtype )
178181 assert a .size == 0
179- return xp .linalg .eigh (a , UPLO = self . UPLO )
182+ return xp .linalg .eigh (a , UPLO = UPLO )
180183
181184 @testing .for_dtypes ("ifdFD" )
182185 @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
183- def test_eigvalsh (self , xp , dtype ):
184- a = xp .empty (self . shape , dtype = dtype )
186+ def test_eigvalsh (self , xp , dtype , shape , UPLO ):
187+ a = xp .empty (shape , dtype = dtype )
185188 assert a .size == 0
186- return xp .linalg .eigvalsh (a , UPLO = self .UPLO )
187-
188-
189- @testing .parameterize (
190- * testing .product (
191- {
192- "UPLO" : ["U" , "L" ],
193- "shape" : [(), (3 ,), (2 , 3 ), (4 , 0 ), (2 , 2 , 3 ), (0 , 2 , 3 )],
194- }
195- )
189+ return xp .linalg .eigvalsh (a , UPLO = UPLO )
190+
191+
192+ @pytest .mark .parametrize ("UPLO" , ["U" , "L" ])
193+ @pytest .mark .parametrize (
194+ "shape" ,
195+ [
196+ (),
197+ (3 ,),
198+ (2 , 3 ),
199+ (4 , 0 ),
200+ (2 , 2 , 3 ),
201+ (0 , 2 , 3 ),
202+ ],
196203)
197204class TestEigenvalueInvalid :
198- def test_eigh_shape_error (self ):
205+
206+ def test_eigh_shape_error (self , UPLO , shape ):
199207 for xp in (numpy , cupy ):
200- a = xp .zeros (self . shape )
208+ a = xp .zeros (shape )
201209 with pytest .raises (xp .linalg .LinAlgError ):
202- xp .linalg .eigh (a , self . UPLO )
210+ xp .linalg .eigh (a , UPLO )
203211
204- def test_eigvalsh_shape_error (self ):
212+ def test_eigvalsh_shape_error (self , UPLO , shape ):
205213 for xp in (numpy , cupy ):
206- a = xp .zeros (self . shape )
214+ a = xp .zeros (shape )
207215 with pytest .raises (xp .linalg .LinAlgError ):
208- xp .linalg .eigvalsh (a , self . UPLO )
216+ xp .linalg .eigvalsh (a , UPLO )
0 commit comments