@@ -332,11 +332,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
332332 // gesvd just knows how to handle m >= n, so in the other case we need to transpose A
333333 const auto not_A_H = A.size (-2 ) >= A.size (-1 );
334334 Tensor Vcopy = V; // Shallow copy
335- #ifdef USE_ROCM
335+ #ifdef ROCM_VERSION
336336 // Similar to the case in svd_magma(), experiments have shown Vh tensor is
337337 // not guaranteed to be column major on ROCM, we have to create a copy to
338338 // deal with this
339- if (!not_A_H) {
339+ if (compute_uv && !not_A_H) {
340340 Vcopy = at::empty_like (V.mT (),
341341 V.options ()
342342 .device (V.device ())
@@ -351,8 +351,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
351351 infos,
352352 full_matrices, compute_uv, calculate_all_batches, batches);
353353 });
354- #ifdef USE_ROCM
355- if (!not_A_H) {
354+ #ifdef ROCM_VERSION
355+ if (compute_uv && !not_A_H) {
356356 V.copy_ (Vcopy);
357357 }
358358#endif
@@ -526,8 +526,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
526526template <typename scalar_t >
527527static void apply_svd_cusolver_gesvdaStridedBatched (const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
528528 const Tensor& infos, bool full_matrices, bool compute_uv) {
529- #ifndef CUDART_VERSION
530- TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend." )
529+ #if defined( CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
530+ TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0 ." )
531531#else
532532 using value_t = typename c10::scalar_value_type<scalar_t >::type;
533533 int m = cuda_int_cast (A.size (-2 ), " m" );
@@ -665,7 +665,7 @@ void svd_cusolver(const Tensor& A,
665665 static constexpr const char * check_svd_doc = " Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html" ;
666666
667667 // The default heuristic is to use gesvdj driver
668- #ifdef USE_ROCM
668+ #if defined(ROCM_VERSION) && ROCM_VERSION < 60100
669669 const auto driver_v = std::string_view (" gesvdj" );
670670#else
671671 const auto driver_v = driver.value_or (" gesvdj" );
0 commit comments