Skip to content

Commit dea9b90

Browse files
nikitaved authored and pytorchmergebot committed
[MAGMA][CUDA] eig: deprecate MAGMA and dispatch to cuSOLVER unconditionally (pytorch#173510)
As per title. Benchmark script:

```python
import torch
import torch.utils.benchmark as benchmark
from itertools import product

results = []
batches = [(), (16,), (64,)]
sizes = [16, 128, 512, 2048]

for b, n in product(batches, sizes):
    shape = b + (n, n)
    print(f"Testing shape={shape}")
    label = "torch.eig"
    sub_label = f"{shape}"
    x = torch.rand(*shape, device="cuda")
    x = x + x.mH
    stmt = "torch.linalg.eig(x)"
    for backend in ("magma", "cusolver"):
        torch.backends.cuda.preferred_linalg_library(backend)
        # warm-up
        for _ in range(5):
            exec(stmt)
        results.append(benchmark.Timer(
            stmt=stmt,
            globals={'x': x},
            label=label,
            sub_label=sub_label,
            description=backend,
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
```

Benchmark results (H100):

```
[-------------------------- torch.eig --------------------------]
                   |    magma    |  cusolver  | cusolver speedup
1 threads: ------------------------------------------------------
(16, 16)           |     67215.3 |      893.3 |            75.24
(128, 128)         |    390101.4 |    10526.9 |            37.05
(512, 512)         |   1614706.1 |    61347.5 |            26.32
(2048, 2048)       |   6447301.5 |   356181.2 |            18.10
(16, 16, 16)       |    660036.1 |    13095.2 |            50.40
(16, 128, 128)     |   6530718.7 |   166647.0 |            39.18
(16, 512, 512)     |  20375827.4 |   994115.1 |            20.49
(16, 2048, 2048)   |  98335490.6 |  5717112.5 |            17.20
(64, 16, 16)       |   2167358.2 |    51977.5 |            41.69
(64, 128, 128)     |  25925259.8 |   664574.7 |            39.01
(64, 512, 512)     |  84731703.1 |  3946917.0 |            21.46
(64, 2048, 2048)   | 380878661.3 | 23008593.5 |            16.55

Times are in microseconds (us).
```

Pull Request resolved: pytorch#173510
Approved by: https://github.com/Skylion007
1 parent c0b01aa commit dea9b90

File tree

4 files changed

+42
-38
lines changed

4 files changed

+42
-38
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,15 @@ void magmaSyevd(
150150
value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork,
151151
magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info);
152152

153+
#ifdef USE_ROCM
153154
template<class scalar_t, class value_t=scalar_t>
154155
void magmaEig(
155156
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda,
156157
scalar_t *w, scalar_t *VL, magma_int_t ldvl,
157158
scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork,
158159
value_t *rwork,
159160
magma_int_t *info);
161+
#endif
160162

161163
template<class scalar_t>
162164
void magmaLuSolve(
@@ -721,6 +723,7 @@ void magmaSyevd<c10::complex<float>, float>(
721723
AT_CUDA_CHECK(cudaGetLastError());
722724
}
723725

726+
#ifdef USE_ROCM
724727
template<>
725728
void magmaEig<double>(
726729
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n,
@@ -800,6 +803,7 @@ void magmaEig<c10::complex<float>, float>(
800803
rwork, info);
801804
AT_CUDA_CHECK(cudaGetLastError());
802805
}
806+
#endif
803807

804808
template<>
805809
void magmaLuSolve<double>(
@@ -971,14 +975,17 @@ magma_trans_t to_magma(TransposeType trans) {
971975

972976
namespace {
973977

974-
void _warn_once_magma_deprecation(const std::string& op_name) {
978+
void _warn_once_magma_deprecation(const std::string& op_name, bool force_cusolver = true) {
975979
if (at::globalContext().linalgPreferredBackend() == at::LinalgBackend::Magma) {
980+
std::string warn_force_cusolver = force_cusolver
981+
? " " + op_name + " will try dispatching to cuSOLVER instead. " +
982+
"If you see any error messages, please, file an issue on GitHub."
983+
: "";
976984
TORCH_WARN_ONCE(
977-
op_name, ": ",
978-
"MAGMA, as a linear algebra backend, is deprecated and will be removed ",
979-
"in future releases. ",
980-
op_name, " will try dispatching to cuSOLVER instead. "
981-
"If you see any error messages, please, file an issue on GitHub."
985+
op_name, ": "
986+
"MAGMA, as a linear algebra backend, is deprecated and will be removed "
987+
"in future releases.",
988+
warn_force_cusolver
982989
);
983990
}
984991
}
@@ -1968,6 +1975,7 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit
19681975
'infos' is an int Tensor containing error codes for each matrix in the batched input.
19691976
For more information see MAGMA's documentation for GEEV routine.
19701977
*/
1978+
#ifdef USE_ROCM
19711979
template <typename scalar_t>
19721980
void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
19731981
#if !AT_MAGMA_ENABLED()
@@ -2045,25 +2053,22 @@ void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
20452053
eigenvectors.copy_(eigenvectors_cpu);
20462054
infos.copy_(infos_cpu);
20472055
}
2056+
#endif // USE_ROCM
2057+
20482058
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
2059+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
20492060
// This function calculates the non-symmetric eigendecomposition in-place
20502061
// tensors should be in batched column major memory format
20512062
// the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or
20522063
// 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy
2053-
2054-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
2055-
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
2056-
auto preferred_backend = at::globalContext().linalgPreferredBackend();
2057-
switch (preferred_backend) {
2058-
case at::LinalgBackend::Cusolver:
2059-
default:
2060-
linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
2061-
return;
2062-
case at::LinalgBackend::Magma:
2063-
break; // MAGMA path handled below
2064-
}
2065-
#endif
2064+
#ifndef USE_ROCM
2065+
_warn_once_magma_deprecation("linalg.eig");
2066+
linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
2067+
#else
2068+
// hipSolver does not have `geev`
2069+
_warn_once_magma_deprecation("linalg.eig", /*force_cusolver=*/false);
20662070
linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
2071+
#endif
20672072
}
20682073

20692074
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues,
7272
bool upper,
7373
bool compute_eigenvectors);
7474

75-
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
75+
#if defined(CUSOLVER_VERSION)
7676
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues,
7777
const Tensor& eigenvectors,
7878
const Tensor& input,

test/test_linalg.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
runOnRocmArch, MI200_ARCH, MI300_ARCH, MI350_ARCH, NAVI_ARCH, TEST_CUDA)
3333
from torch.testing._internal.common_device_type import \
3434
(instantiate_device_type_tests, dtypes, has_cusolver, onlyCPU, skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
35+
skipCUDAIfNoCusolverROCMIfNoMagma,
3536
skipCUDAIfNoCusolver, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
3637
onlyCUDA, skipMeta, skipCUDAIfNotRocm, dtypesIfMPS, largeTensorTest)
3738
from torch.testing import make_tensor
@@ -2255,7 +2256,7 @@ def test_norm_fastpaths(self, device):
22552256
self.assertEqual(result, expected)
22562257

22572258
@skipCPUIfNoLapack
2258-
@skipCUDAIfNoMagma
2259+
@skipCUDAIfNoCusolverROCMIfNoMagma
22592260
# NumPy computes only in float64 and complex128 precisions
22602261
# for float32 or complex64 results might be very different from float64 or complex128
22612262
@dtypes(torch.float64, torch.complex128)
@@ -2304,7 +2305,7 @@ def run_test(shape, *, symmetric=False):
23042305
run_test(shape, symmetric=True)
23052306

23062307
@onlyCUDA
2307-
@skipCUDAIfNoMagma
2308+
@skipCUDAIfNoCusolverROCMIfNoMagma
23082309
@dtypes(*floating_and_complex_types())
23092310
def test_eig_identity(self, device, dtype):
23102311

@@ -2414,7 +2415,7 @@ def run_test(shape, *, symmetric=False):
24142415

24152416

24162417
@onlyCUDA
2417-
@skipCUDAIfNoMagmaAndNoCusolver
2418+
@skipCUDAIfNoCusolverROCMIfNoMagma
24182419
@dtypes(*floating_and_complex_types())
24192420
def test_eig_out_variants(self, device, dtype):
24202421
from torch.testing._internal.common_utils import random_symmetric_matrix
@@ -2466,19 +2467,8 @@ def run_test(shape, *, symmetric=False):
24662467
run_test(shape, symmetric=True)
24672468

24682469

2469-
@slowTest
2470-
@onlyCUDA
2471-
@skipCUDAIfNoMagma
2472-
@dtypes(torch.float32)
2473-
def test_eig_check_magma(self, device, dtype):
2474-
# For CUDA inputs only matrices of size larger than 2048x2048 actually call MAGMA library
2475-
shape = (2049, 2049)
2476-
a = make_tensor(shape, dtype=dtype, device=device)
2477-
w, v = torch.linalg.eig(a)
2478-
# check correctness using eigendecomposition identity
2479-
self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)
2480-
24812470
@onlyCUDA
2471+
@skipCUDAIfNoCusolverROCMIfNoMagma
24822472
@dtypes(torch.float32, torch.float64)
24832473
def test_eig_cuda_complex_eigenvectors(self, device, dtype):
24842474
"""Test CUDA eigenvector decoding with known ground truth, including batching."""
@@ -2563,8 +2553,8 @@ def test_eig_cuda_complex_eigenvectors(self, device, dtype):
25632553
rhs = vals_batch.unsqueeze(-2) * vecs_batch
25642554
self.assertEqual(lhs, rhs, atol=1e-5, rtol=1e-5)
25652555

2566-
@skipCUDAIfNoMagma
25672556
@skipCPUIfNoLapack
2557+
@skipCUDAIfNoCusolverROCMIfNoMagma
25682558
@dtypes(*floating_and_complex_types())
25692559
def test_eig_errors_and_warnings(self, device, dtype):
25702560
# eig requires the input to be at least 2 dimensional tensor
@@ -2626,7 +2616,7 @@ def test_eig_errors_and_warnings(self, device, dtype):
26262616
torch.linalg.eig(a, out=(out_w, out_v))
26272617

26282618
@skipCPUIfNoLapack
2629-
@skipCUDAIfNoMagma
2619+
@skipCUDAIfNoCusolverROCMIfNoMagma
26302620
@dtypes(*floating_and_complex_types())
26312621
def test_eig_with_nan(self, device, dtype):
26322622
for val in [np.inf, np.nan]:
@@ -3110,8 +3100,8 @@ def mul_svd_factors(U, S, Vh):
31103100
S_s = torch.svd(A, compute_uv=False).S
31113101
self.assertEqual(S_s, S)
31123102

3113-
@skipCUDAIfNoMagmaAndNoCusolver
31143103
@skipCPUIfNoLapack
3104+
@skipCUDAIfNoCusolverROCMIfNoMagma
31153105
@dtypes(torch.complex128)
31163106
def test_invariance_error_spectral_decompositions(self, device, dtype):
31173107
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)

torch/testing/_internal/common_device_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,15 @@ def has_hipsolver():
18281828
return rocm_version >= (5, 3)
18291829

18301830

1831+
# Skips a test on CUDA if cuSOLVER is not available,
1832+
# and on ROCm if MAGMA is not available.
1833+
def skipCUDAIfNoCusolverROCMIfNoMagma(fn):
1834+
if TEST_WITH_ROCM:
1835+
return skipCUDAIfNoMagma(fn)
1836+
else:
1837+
return skipCUDAIfNoCusolver(fn)
1838+
1839+
18311840
# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available
18321841
def skipCUDAIfNoCusolver(fn):
18331842
return skipCUDAIf(

0 commit comments

Comments (0)