Skip to content

Commit eb12f2d

Browse files
committed
Update linalg_tests/test_eigenvalue.py
1 parent 94189b8 commit eb12f2d

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_eigenvalue.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def _get_hermitian(xp, a, UPLO):
2121
)
2222
)
2323
class TestEigenvalue:
24+
2425
@testing.for_all_dtypes()
2526
@testing.numpy_cupy_allclose(
2627
rtol=1e-3,
@@ -47,9 +48,7 @@ def test_eigh(self, xp, dtype):
4748
tol = 1e-3
4849
else:
4950
tol = 1e-5
50-
5151
testing.assert_allclose(A @ v, v @ xp.diag(w), atol=tol, rtol=tol)
52-
5352
# Check if v @ vt is an identity matrix
5453
testing.assert_allclose(
5554
v @ v.swapaxes(-2, -1).conj(),
@@ -87,7 +86,7 @@ def test_eigh_batched(self, xp, dtype):
8786
)
8887
return w
8988

90-
@testing.for_complex_dtypes()
89+
@testing.for_dtypes("FD")
9190
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-4)
9291
def test_eigh_complex_batched(self, xp, dtype):
9392
a = xp.array(
@@ -105,7 +104,6 @@ def test_eigh_complex_batched(self, xp, dtype):
105104
# eigenvectors, so v's are not directly comparable and we verify
106105
# them through the eigen equation A*v=w*v.
107106
A = _get_hermitian(xp, a, self.UPLO)
108-
109107
for i in range(a.shape[0]):
110108
testing.assert_allclose(
111109
A[i].dot(v[i]), w[i] * v[i], rtol=1e-5, atol=1e-5
@@ -165,44 +163,54 @@ def test_eigvalsh_complex_batched(self, xp, dtype):
165163
return w
166164

167165

168-
@testing.parameterize(
169-
*testing.product(
170-
{"UPLO": ["U", "L"], "shape": [(0, 0), (2, 0, 0), (0, 3, 3)]}
171-
)
166+
@pytest.mark.parametrize("UPLO", ["U", "L"])
167+
@pytest.mark.parametrize(
168+
"shape",
169+
[
170+
(0, 0),
171+
(2, 0, 0),
172+
(0, 3, 3),
173+
],
172174
)
173175
class TestEigenvalueEmpty:
176+
174177
@testing.for_dtypes("ifdFD")
175178
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
176-
def test_eigh(self, xp, dtype):
177-
a = xp.empty(self.shape, dtype=dtype)
179+
def test_eigh(self, xp, dtype, shape, UPLO):
180+
a = xp.empty(shape, dtype=dtype)
178181
assert a.size == 0
179-
return xp.linalg.eigh(a, UPLO=self.UPLO)
182+
return xp.linalg.eigh(a, UPLO=UPLO)
180183

181184
@testing.for_dtypes("ifdFD")
182185
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
183-
def test_eigvalsh(self, xp, dtype):
184-
a = xp.empty(self.shape, dtype=dtype)
186+
def test_eigvalsh(self, xp, dtype, shape, UPLO):
187+
a = xp.empty(shape, dtype=dtype)
185188
assert a.size == 0
186-
return xp.linalg.eigvalsh(a, UPLO=self.UPLO)
187-
188-
189-
@testing.parameterize(
190-
*testing.product(
191-
{
192-
"UPLO": ["U", "L"],
193-
"shape": [(), (3,), (2, 3), (4, 0), (2, 2, 3), (0, 2, 3)],
194-
}
195-
)
189+
return xp.linalg.eigvalsh(a, UPLO=UPLO)
190+
191+
192+
@pytest.mark.parametrize("UPLO", ["U", "L"])
193+
@pytest.mark.parametrize(
194+
"shape",
195+
[
196+
(),
197+
(3,),
198+
(2, 3),
199+
(4, 0),
200+
(2, 2, 3),
201+
(0, 2, 3),
202+
],
196203
)
197204
class TestEigenvalueInvalid:
198-
def test_eigh_shape_error(self):
205+
206+
def test_eigh_shape_error(self, UPLO, shape):
199207
for xp in (numpy, cupy):
200-
a = xp.zeros(self.shape)
208+
a = xp.zeros(shape)
201209
with pytest.raises(xp.linalg.LinAlgError):
202-
xp.linalg.eigh(a, self.UPLO)
210+
xp.linalg.eigh(a, UPLO)
203211

204-
def test_eigvalsh_shape_error(self):
212+
def test_eigvalsh_shape_error(self, UPLO, shape):
205213
for xp in (numpy, cupy):
206-
a = xp.zeros(self.shape)
214+
a = xp.zeros(shape)
207215
with pytest.raises(xp.linalg.LinAlgError):
208-
xp.linalg.eigvalsh(a, self.UPLO)
216+
xp.linalg.eigvalsh(a, UPLO)

0 commit comments

Comments
 (0)