diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 3bb91e2f01..106eb6c4e8 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -50,6 +50,36 @@ namespace BlasUtils{ return CUBLAS_OP_N; } + cublasSideMode_t judge_side(const char& trans) + { + if (trans == 'L') + { + return CUBLAS_SIDE_LEFT; + } + else if (trans == 'R') + { + return CUBLAS_SIDE_RIGHT; + } + return CUBLAS_SIDE_LEFT; + } + + cublasFillMode_t judge_fill(const char& trans) + { + if (trans == 'F') + { + return CUBLAS_FILL_MODE_FULL; + } + else if (trans == 'U') + { + return CUBLAS_FILL_MODE_UPPER; + } + else if (trans == 'D') + { + return CUBLAS_FILL_MODE_LOWER; + } + return CUBLAS_FILL_MODE_FULL; + } + } // namespace BlasUtils #endif @@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } void BlasConnector::hemm_cm(char side, char uplo, int m, int n, @@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); +#endif + } } void BlasConnector::hemm_cm(char side, char uplo, int m, int n, @@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); + cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); + cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag()); + cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag()); + cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy)); +#endif + } } // out = ||x||_2 @@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev if (device_type == base_device::AbacusDevice_t::CpuDevice) { return snrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + float result = 0.0; + cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; +#endif + } return snrm2_( &n, X, &incX ); } @@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dnrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + double result = 0.0; + cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; +#endif + } return dnrm2_( &n, X, &incX ); } @@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dznrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + double result = 0.0; + cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result)); + return result; +#endif + } return dznrm2_( &n, X, &incX ); }