Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 125 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ namespace BlasUtils{
return CUBLAS_OP_N;
}

/**
 * @brief Translate a BLAS-style `side` character into a cuBLAS side mode.
 *
 * 'L' / 'l' select CUBLAS_SIDE_LEFT and 'R' / 'r' select CUBLAS_SIDE_RIGHT.
 * Lowercase is accepted because Fortran BLAS treats its option characters
 * case-insensitively; without it a lowercase 'r' would silently be handled
 * as LEFT here while a CPU BLAS call given the same character honours it.
 *
 * @param trans Side selector character (named `trans` for historical reasons).
 * @return The matching cublasSideMode_t; any unrecognised character falls
 *         back to CUBLAS_SIDE_LEFT (unchanged legacy default).
 */
cublasSideMode_t judge_side(const char& trans)
{
    switch (trans)
    {
    case 'R':
    case 'r':
        return CUBLAS_SIDE_RIGHT;
    case 'L':
    case 'l':
    default:
        // Unknown selectors keep the historical LEFT default.
        return CUBLAS_SIDE_LEFT;
    }
}

/**
 * @brief Translate a BLAS-style `uplo` character into a cuBLAS fill mode.
 *
 * Recognised selectors:
 *   - 'U' / 'u'      -> CUBLAS_FILL_MODE_UPPER
 *   - 'L' / 'l'      -> CUBLAS_FILL_MODE_LOWER (standard BLAS spelling)
 *   - 'D'            -> CUBLAS_FILL_MODE_LOWER (legacy spelling, kept for
 *                       backward compatibility)
 *   - 'F' / 'f'      -> CUBLAS_FILL_MODE_FULL
 *
 * The standard BLAS character for the lower triangle is 'L'. Previously it
 * was not recognised and fell through to CUBLAS_FILL_MODE_FULL, so a caller
 * passing the same `uplo` to a CPU BLAS routine and to this helper got a
 * different (and, for symm/hemm, invalid) fill mode on the GPU.
 *
 * @param trans Fill selector character (named `trans` for historical reasons).
 * @return The matching cublasFillMode_t; any unrecognised character falls
 *         back to CUBLAS_FILL_MODE_FULL (unchanged legacy default).
 */
cublasFillMode_t judge_fill(const char& trans)
{
    switch (trans)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L':
    case 'l':
    case 'D': // legacy selector for the lower triangle
        return CUBLAS_FILL_MODE_LOWER;
    case 'F':
    case 'f':
    default:
        // Unknown selectors keep the historical FULL default.
        return CUBLAS_FILL_MODE_FULL;
    }
}

} // namespace BlasUtils

#endif
Expand Down Expand Up @@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
Expand All @@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
Expand All @@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag());
cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag());
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag());
cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag());
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
#endif
}
}

// out = ||x||_2
Expand All @@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return snrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
return result;
#endif
}
return snrm2_( &n, X, &incX );
}

Expand All @@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dnrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
return result;
#endif
}
return dnrm2_( &n, X, &incX );
}

Expand All @@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dznrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
return result;
#endif
}
return dznrm2_( &n, X, &incX );
}

Expand Down
Loading