Skip to content

Commit dd6cb8c

Browse files
Improve eig tests in preparation for new eig backends
- Verify validity of eigenvectors using the eigen decomposition identity for improved robustness, as eigenvectors are not unique. - Increases test reliability across backends (cuSOLVER, MAGMA, CPU). - Tolerances derived from numerical comparisons between cuSOLVER and NumPy. See discussion: https://dev-discuss.pytorch.org/t/cusolver-dnxgeev-faster-cuda-eigenvalue-calculations/3248/6
1 parent 8cee9e6 commit dd6cb8c

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

test/test_linalg.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2139,6 +2139,34 @@ def run_test(shape, *, symmetric=False):
21392139
@skipCUDAIfNoMagma
21402140
@dtypes(*floating_and_complex_types())
21412141
def test_eig_compare_backends(self, device, dtype):
2142+
2143+
def fulfills_eigen_decomposition_identity(a, eig, dtype):
2144+
2145+
# check correctness using eigendecomposition identity
2146+
if dtype in [torch.float32, torch.complex64]:
2147+
atol = 1e-3 # CUDA seems to give less accurate results for float32 and complex64. They are about one to two OOM from NumPy results
2148+
else:
2149+
atol = 1e-13 # Same OOM for NumPy
2150+
2151+
2152+
w, v = eig
2153+
2154+
# move all tensors to CPU to avoid issues with GPU matmul
2155+
a = a.cpu().to(v.dtype)
2156+
v = v.to("cpu")
2157+
w = w.to("cpu")
2158+
2159+
if a.numel() == 0 and v.numel() == 0 and w.numel() == 0:
2160+
return True
2161+
elif a.numel() == 0 or v.numel() == 0 or w.numel() == 0:
2162+
return False
2163+
2164+
diff = (a @ v) - (v * w.unsqueeze(-2))
2165+
diff = diff.abs()
2166+
diff = torch.max(diff)
2167+
print(f"diff: {diff}")
2168+
return diff <= atol
2169+
21422170
def run_test(shape, *, symmetric=False):
21432171
from torch.testing._internal.common_utils import random_symmetric_matrix
21442172

@@ -2155,7 +2183,7 @@ def run_test(shape, *, symmetric=False):
21552183
# compare with CPU
21562184
expected = torch.linalg.eig(a.to(complementary_device))
21572185
self.assertEqual(expected[0], actual[0])
2158-
self.assertEqual(expected[1], actual[1])
2186+
self.assertTrue(fulfills_eigen_decomposition_identity(a, actual, dtype)) # check evs using eigen identity
21592187

21602188
shapes = [(0, 0), # Empty matrix
21612189
(5, 5), # Single matrix
@@ -2345,6 +2373,8 @@ def run_test(shape, *, symmetric=False):
23452373
run_test(shape)
23462374
run_test(shape, symmetric=True)
23472375

2376+
2377+
23482378
@skipCUDAIfNoMagma
23492379
@skipCPUIfNoLapack
23502380
@dtypes(*floating_and_complex_types())

0 commit comments

Comments
 (0)