From 55461eef63e5d1c581a6303491d62f9e7fde8970 Mon Sep 17 00:00:00 2001 From: critsium-xy Date: Tue, 7 Jan 2025 22:03:09 +0800 Subject: [PATCH 1/3] initial commit --- source/module_base/blas_connector.cpp | 141 +++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 4 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 3bb91e2f01..acf55a0518 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -50,6 +50,36 @@ namespace BlasUtils{ return CUBLAS_OP_N; } + cublasSideMode_t judge_side(const char& trans) + { + if (trans == "L") + { + return CUBLAS_SIDE_LEFT; + } + else if (trans == "R") + { + return CUBLAS_SIDE_RIGHT; + } + return CUBLAS_SIDE_LEFT; + } + + cublasFillMode_t judge_fill(const char& trans) + { + if (trans == "F") + { + return CUBLAS_FILL_MODE_FULL; + } + else if (trans == "U") + { + return CUBLAS_FILL_MODE_UPPER; + } + else if (trans == "D") + { + return CUBLAS_FILL_MODE_LOWER; + } + return CUBLAS_FILL_MODE_FULL; + } + } // namespace BlasUtils #endif @@ -184,6 +214,22 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d return ddot_(&n, X, &incX, Y, &incY); } +double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return ddot_(&n, X, &incX, Y, &incY); + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + double result = 0.0; + cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); + return result; +#endif + } + return ddot_(&n, X, &incX, Y, &incY); +} + + // C = a * A.? * B.? + b * C // Row-Major part void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -398,6 +444,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -409,6 +462,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -420,6 +480,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); +#endif + } } void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, @@ -431,6 +498,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } void BlasConnector::hemm_cm(char side, char uplo, int m, int n, @@ -442,6 +516,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); +#endif + } } void BlasConnector::hemm_cm(char side, char uplo, int m, int n, @@ -453,6 +534,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasSideMode_t sideMode = BlasUtils::judge_side(transa); + cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -461,7 +549,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op"); + cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -470,7 +564,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op"); + cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -479,7 +579,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op"); + cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy)); +#endif + } } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -488,7 +594,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n, { if (device_type == base_device::AbacusDevice_t::CpuDevice) { zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op"); + cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy)); +#endif + } } // out = ||x||_2 @@ -497,6 +609,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev if (device_type == base_device::AbacusDevice_t::CpuDevice) { return snrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + float result = 0.0; + cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; +#endif + } return snrm2_( &n, X, &incX ); } @@ -506,6 +625,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dnrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + double result = 0.0; + cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; +#endif + } return dnrm2_( &n, X, &incX ); } @@ -515,6 +641,13 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dznrm2_( &n, X, &incX ); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + std::complex result = 0.0; + cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*)&result)); + return result; +#endif + } return dznrm2_( &n, X, &incX ); } From 0e4118502aa94bfd1938b526dc38ac4e25fdae6a Mon Sep 17 00:00:00 2001 From: critsium-xy Date: Wed, 8 Jan 2025 11:49:51 +0800 Subject: [PATCH 2/3] Fix compiling error --- source/module_base/blas_connector.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index acf55a0518..8347ef29d0 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -214,22 +214,6 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d return ddot_(&n, X, &incX, Y, &incY); } -double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - return ddot_(&n, X, &incX, Y, &incY); - } - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - double result = 0.0; - cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); - return result; -#endif - } - return ddot_(&n, X, &incX, Y, &incY); -} - - // C = a * A.? * B.? + b * C // Row-Major part void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, From 583920aac5a48c465f6d4db2d16b68c161d7b08e Mon Sep 17 00:00:00 2001 From: critsium-xy Date: Wed, 8 Jan 2025 15:52:57 +0800 Subject: [PATCH 3/3] Fix trans comparison bug --- source/module_base/blas_connector.cpp | 56 ++++++++++++++------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 8347ef29d0..106eb6c4e8 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -52,11 +52,11 @@ namespace BlasUtils{ cublasSideMode_t judge_side(const char& trans) { - if (trans == "L") + if (trans == 'L') { return CUBLAS_SIDE_LEFT; } - else if (trans == "R") + else if (trans == 'R') { return CUBLAS_SIDE_RIGHT; } @@ -65,15 +65,15 @@ namespace BlasUtils{ cublasFillMode_t judge_fill(const char& trans) { - if (trans == "F") + if (trans == 'F') { return CUBLAS_FILL_MODE_FULL; } - else if (trans == "U") + else if (trans == 'U') { return CUBLAS_FILL_MODE_UPPER; } - else if (trans == "D") + else if (trans == 'D') { return CUBLAS_FILL_MODE_LOWER; } @@ -430,9 +430,9 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); - cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); #endif } } @@ -448,8 +448,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); #endif } @@ -466,8 +466,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); #endif } @@ -484,8 +484,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); #endif } @@ -502,8 +502,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); #endif } @@ -520,8 +520,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasSideMode_t sideMode = BlasUtils::judge_side(transa); - cublasOperation_t fillMode = BlasUtils::judge_fill(transb); + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); #endif } @@ -536,7 +536,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op"); + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); #endif } @@ -551,7 +551,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op"); + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); #endif } @@ -566,8 +566,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op"); - cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy)); + cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag()); + cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy)); #endif } } @@ -581,8 +583,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n, } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op"); - cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy)); + cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag()); + cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy)); #endif } } @@ -627,8 +631,8 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - std::complex result = 0.0; - cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*)&result)); + double result = 0.0; + cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result)); return result; #endif }