Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 190 additions & 36 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,100 @@
#include "module_base/global_variable.h"
#endif

#ifdef __CUDA
#include <base/macros/macros.h>
#include <cuda_runtime.h>
#include <thrust/complex.h>
#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include "module_base/tool_quit.h"

#include "cublas_v2.h"

namespace BlasUtils{

// File-local cuBLAS handle used by the GPU branches of BlasConnector in this file.
// It holds nullptr until createGpuBlasHandle() has run.
static cublasHandle_t cublas_handle = nullptr;

// Create the cuBLAS handle. If a handle already exists the call does nothing.
void createGpuBlasHandle(){
    if (cublas_handle != nullptr) {
        return;
    }
    cublasErrcheck(cublasCreate(&cublas_handle));
}

// Destroy the cuBLAS handle, if one exists, and reset it to nullptr so that
// createGpuBlasHandle() can build a fresh one afterwards.
// (The spelling "destory" is kept because it is the existing public name.)
void destoryBLAShandle(){
    if (cublas_handle == nullptr) {
        return;
    }
    cublasErrcheck(cublasDestroy(cublas_handle));
    cublas_handle = nullptr;
}

} // namespace BlasUtils

// Translate a BLAS transpose flag character into the matching cuBLAS operation.
//
// is_complex : true when the operands are complex; only then does 'C' select
//              CUBLAS_OP_C (conjugate transpose).
// trans      : BLAS flag 'N', 'T' or 'C'; lower-case letters are accepted as well,
//              as in reference BLAS.
// name       : label of the calling routine (e.g. "gemm_op"); currently unused.
//
// Returns CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C. An unrecognised flag falls back
// to CUBLAS_OP_N, which is the behaviour callers had before.
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
    (void)name; // kept for interface compatibility
    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        // For real data the conjugate transpose is just the transpose; the previous
        // code dropped through to CUBLAS_OP_N here and silently skipped the transpose.
        return is_complex ? CUBLAS_OP_C : CUBLAS_OP_T;
    default:
        return CUBLAS_OP_N;
    }
}

#endif

// Single-precision AXPY (Y := alpha * X + Y), dispatched on device_type.
//   CpuDevice : forwards to BLAS saxpy_.
//   GpuDevice : forwards to cublasSaxpy using BlasUtils::cublas_handle; when the
//               file is built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X and Y are presumably device pointers and the
// handle must already have been created - confirm against the callers.
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        saxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
    }
}

// Double-precision AXPY (Y := alpha * X + Y), dispatched on device_type.
//   CpuDevice : forwards to BLAS daxpy_.
//   GpuDevice : forwards to cublasDaxpy using BlasUtils::cublas_handle; when the
//               file is built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X and Y are presumably device pointers and the
// handle must already have been created - confirm against the callers.
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        daxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
    }
}

// Single-precision complex AXPY (Y := alpha * X + Y), dispatched on device_type.
//   CpuDevice : forwards to BLAS caxpy_.
//   GpuDevice : forwards to cublasCaxpy using BlasUtils::cublas_handle; the
//               std::complex<float> arguments are reinterpreted as float2.
//               Built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X and Y are presumably device pointers and the
// handle must already have been created - confirm against the callers.
void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        caxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        // Const-preserving casts: only Y is written by cuBLAS.
        cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n,
                                   reinterpret_cast<const float2*>(&alpha),
                                   reinterpret_cast<const float2*>(X), incX,
                                   reinterpret_cast<float2*>(Y), incY));
#endif
    }
}

// Double-precision complex AXPY (Y := alpha * X + Y), dispatched on device_type.
//   CpuDevice : forwards to BLAS zaxpy_.
//   GpuDevice : forwards to cublasZaxpy using BlasUtils::cublas_handle; the
//               std::complex<double> arguments are reinterpreted as double2.
//               Built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X and Y are presumably device pointers and the
// handle must already have been created - confirm against the callers.
void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        zaxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        // Const-preserving casts: only Y is written by cuBLAS.
        cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n,
                                   reinterpret_cast<const double2*>(&alpha),
                                   reinterpret_cast<const double2*>(X), incX,
                                   reinterpret_cast<double2*>(Y), incY));
#endif
    }
}


Expand All @@ -39,28 +107,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

// Double-precision SCAL (X := alpha * X), dispatched on device_type.
//   CpuDevice : forwards to BLAS dscal_.
//   GpuDevice : forwards to cublasDscal using BlasUtils::cublas_handle; when the
//               file is built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X is presumably a device pointer and the handle
// must already have been created - confirm against the callers.
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        dscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
    }
}

// Single-precision complex SCAL (X := alpha * X), dispatched on device_type.
//   CpuDevice : forwards to BLAS cscal_.
//   GpuDevice : forwards to cublasCscal using BlasUtils::cublas_handle; the
//               std::complex<float> arguments are reinterpreted as float2.
//               Built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X is presumably a device pointer and the handle
// must already have been created - confirm against the callers.
void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        cscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        // alpha is read-only, X is updated in place.
        cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n,
                                   reinterpret_cast<const float2*>(&alpha),
                                   reinterpret_cast<float2*>(X), incX));
#endif
    }
}

// Double-precision complex SCAL (X := alpha * X), dispatched on device_type.
//   CpuDevice : forwards to BLAS zscal_.
//   GpuDevice : forwards to cublasZscal using BlasUtils::cublas_handle; the
//               std::complex<double> arguments are reinterpreted as double2.
//               Built without __CUDA this branch is empty and nothing happens.
//   any other device type: nothing happens.
// NOTE(review): on the GPU path X is presumably a device pointer and the handle
// must already have been created - confirm against the callers.
void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        zscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        // alpha is read-only, X is updated in place.
        cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n,
                                   reinterpret_cast<const double2*>(&alpha),
                                   reinterpret_cast<double2*>(X), incX));
#endif
    }
}


Expand All @@ -70,6 +158,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return sdot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -78,6 +173,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return ddot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -91,13 +193,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -109,13 +218,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -127,13 +243,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -145,49 +268,80 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
const float alpha, const float* A, const int lda, const float* X, const int incX,
const float beta, float* Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const double alpha, const double* A, const int lda, const double* X, const int incx,
const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type)
const double alpha, const double* A, const int lda, const double* X, const int incX,
const double beta, double* Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incx,
const std::complex<float> beta, std::complex<float> *Y, const int incy, base_device::AbacusDevice_t device_type)
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incX,
const std::complex<float> beta, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutrans, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incX, (float2*)&beta, (float2*)Y, incY));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incx,
const std::complex<double> beta, std::complex<double> *Y, const int incy, base_device::AbacusDevice_t device_type)
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incX,
const std::complex<double> beta, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutrans, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incX, (double2*)&beta, (double2*)Y, incY));
#endif
}
}


Expand Down Expand Up @@ -219,16 +373,16 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
}

// copies a into b
void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type)
void BlasConnector::copy(const long n, const double *a, const int incX, double *b, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dcopy_(&n, a, &incx, b, &incy);
dcopy_(&n, a, &incX, b, &incY);
}
}

void BlasConnector::copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type)
void BlasConnector::copy(const long n, const std::complex<double> *a, const int incX, std::complex<double> *b, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zcopy_(&n, a, &incx, b, &incy);
zcopy_(&n, a, &incX, b, &incY);
}
}
Loading
Loading