Skip to content

Commit 6b036a1

Browse files
xinyazhangpragupta
authored andcommitted
Enable gesvda for ROCM >= 6.1 (#1339)
This also fixes a problem in gesvd driver when UV is not needed. (cherry picked from commit 4ce57ec) (cherry picked from commit 167b4c1)
1 parent 061f369 commit 6b036a1

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
332332
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
333333
const auto not_A_H = A.size(-2) >= A.size(-1);
334334
Tensor Vcopy = V; // Shallow copy
335-
#ifdef USE_ROCM
335+
#ifdef ROCM_VERSION
336336
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
337337
// not guaranteed to be column major on ROCM, we have to create a copy to
338338
// deal with this
339-
if (!not_A_H) {
339+
if (compute_uv && !not_A_H) {
340340
Vcopy = at::empty_like(V.mT(),
341341
V.options()
342342
.device(V.device())
@@ -351,8 +351,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
351351
infos,
352352
full_matrices, compute_uv, calculate_all_batches, batches);
353353
});
354-
#ifdef USE_ROCM
355-
if (!not_A_H) {
354+
#ifdef ROCM_VERSION
355+
if (compute_uv && !not_A_H) {
356356
V.copy_(Vcopy);
357357
}
358358
#endif
@@ -526,8 +526,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
526526
template<typename scalar_t>
527527
static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
528528
const Tensor& infos, bool full_matrices, bool compute_uv) {
529-
#ifndef CUDART_VERSION
530-
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
529+
#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
530+
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
531531
#else
532532
using value_t = typename c10::scalar_value_type<scalar_t>::type;
533533
int m = cuda_int_cast(A.size(-2), "m");
@@ -665,7 +665,7 @@ void svd_cusolver(const Tensor& A,
665665
static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
666666

667667
// The default heuristic is to use gesvdj driver
668-
#ifdef USE_ROCM
668+
#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
669669
const auto driver_v = std::string_view("gesvdj");
670670
#else
671671
const auto driver_v = driver.value_or("gesvdj");

aten/src/ATen/native/cuda/linalg/CUDASolver.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ void gesvdjBatched<c10::complex<double>>(
470470
}
471471

472472

473-
// ROCM does not implement gesdva yet
474-
#ifdef CUDART_VERSION
473+
// ROCM does not implement gesdva correctly before 6.1
474+
#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
475475
template<>
476476
void gesvdaStridedBatched_buffersize<float>(
477477
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,

0 commit comments

Comments
 (0)