Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 183 additions & 24 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,101 @@
#include "module_base/global_variable.h"
#endif

#ifdef __CUDA
#include <base/macros/macros.h>
#include <cuda_runtime.h>
#include <thrust/complex.h>
#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include "module_base/tool_quit.h"

#include "cublas_v2.h"

namespace BlasUtils{

static cublasHandle_t cublas_handle = nullptr;

void createGpuBlasHandle(){
if (cublas_handle == nullptr) {
cublasErrcheck(cublasCreate(&cublas_handle));
}
}

void destoryBLAShandle(){
if (cublas_handle != nullptr) {
cublasErrcheck(cublasDestroy(cublas_handle));
cublas_handle = nullptr;
}
}


cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
if (trans == 'N')
{
return CUBLAS_OP_N;
}
else if(trans == 'T')
{
return CUBLAS_OP_T;
}
else if(is_complex && trans == 'C')
{
return CUBLAS_OP_C;
}
return CUBLAS_OP_N;
}

} // namespace BlasUtils

#endif

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
saxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
daxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
caxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
#endif
}
}


Expand All @@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
#endif
}
}


Expand All @@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return sdot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return ddot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Col-Major part
Expand All @@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Symm and Hemm part. Only col-major is supported.
Expand Down
Loading
Loading