Skip to content

Commit 6cc4706

Browse files
committed
Complete CUDA implement on blas_connector.cpp
1 parent e555490 commit 6cc4706

File tree

1 file changed

+170
-20
lines changed

1 file changed

+170
-20
lines changed

source/module_base/blas_connector.cpp

Lines changed: 170 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,96 @@
55
#include "module_base/global_variable.h"
66
#endif
77

8+
#ifdef __CUDA
9+
#include <base/macros/macros.h>
10+
#include <cuda_runtime.h>
11+
#include <thrust/complex.h>
12+
#include <thrust/execution_policy.h>
13+
#include <thrust/inner_product.h>
14+
15+
static cublasHandle_t cublas_handle = nullptr;
16+
17+
void createGpuBlasHandle(){
18+
if (cublas_handle == nullptr) {
19+
cublasErrcheck(cublasCreate(&cublas_handle));
20+
}
21+
}
22+
23+
void destoryBLAShandle(){
24+
if (cublas_handle != nullptr) {
25+
cublasErrcheck(cublasDestroy(cublas_handle));
26+
cublas_handle = nullptr;
27+
}
28+
}
29+
30+
// Translates a BLAS-style transpose flag into the matching cublasOperation_t.
//   is_complex  true when the caller operates on complex data; the conjugate
//               transpose flag 'C' is only accepted in that case
//   trans       'N', 'T' or 'C'; the lower-case spellings are accepted too,
//               matching what the Fortran BLAS routines on the CPU path accept
//   name        label forwarded to ModuleBase::WARNING_QUIT for an unknown flag
// Returns CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C.
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
{
    if (trans == 'N' || trans == 'n')
    {
        return CUBLAS_OP_N;
    }
    if (trans == 'T' || trans == 't')
    {
        return CUBLAS_OP_T;
    }
    if (is_complex && (trans == 'C' || trans == 'c'))
    {
        return CUBLAS_OP_C;
    }

    ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
    // NOTE(review): WARNING_QUIT presumably terminates the program - confirm.
    // The return below only guarantees that every control path yields a value,
    // because flowing off the end of a non-void function is undefined behaviour.
    return CUBLAS_OP_N;
}
49+
50+
#endif
51+
852
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
953
{
1054
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1155
saxpy_(&n, &alpha, X, &incX, Y, &incY);
12-
}
56+
}
57+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
58+
#ifdef __CUDA
59+
cublasErrcheck(cublasSaxpy(cublas_handle, n, alpha, X, incX, Y, incY));
60+
#endif
61+
}
1362
}
1463

1564
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
1665
{
1766
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1867
daxpy_(&n, &alpha, X, &incX, Y, &incY);
19-
}
68+
}
69+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
70+
#ifdef __CUDA
71+
cublasErrcheck(cublasDaxpy(cublas_handle, n, alpha, X, incX, Y, incY));
72+
#endif
73+
}
2074
}
2175

2276
void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
2377
{
2478
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
2579
caxpy_(&n, &alpha, X, &incX, Y, &incY);
26-
}
80+
}
81+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
82+
#ifdef __CUDA
83+
cublasErrcheck(cublasCaxpy(cublas_handle, n, alpha, X, incX, Y, incY));
84+
#endif
85+
}
2786
}
2887

2988
void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
3089
{
3190
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
3291
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
33-
}
92+
}
93+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
94+
#ifdef __CUDA
95+
cublasErrcheck(cublasZaxpy(cublas_handle, n, alpha, X, incX, Y, incY));
96+
#endif
97+
}
3498
}
3599

36100

@@ -39,28 +103,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
39103
{
40104
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
41105
sscal_(&n, &alpha, X, &incX);
42-
}
106+
}
107+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
108+
#ifdef __CUDA
109+
cublasErrcheck(cublasSscal(cublas_handle, n, (float2*)alpha, (float2*)X, incx));
110+
#endif
111+
}
43112
}
44113

45114
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
46115
{
47116
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
48117
dscal_(&n, &alpha, X, &incX);
49-
}
118+
}
119+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
120+
#ifdef __CUDA
121+
cublasErrcheck(cublasDscal(cublas_handle, n, (double2*)alpha, (double2*)X, incx));
122+
#endif
123+
}
50124
}
51125

52126
void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
53127
{
54128
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
55129
cscal_(&n, &alpha, X, &incX);
56-
}
130+
}
131+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
132+
#ifdef __CUDA
133+
cublasErrcheck(cublasCscal(cublas_handle, n, (float2*)alpha, (float2*)X, incx));
134+
#endif
135+
}
57136
}
58137

59138
void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
60139
{
61140
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
62141
zscal_(&n, &alpha, X, &incX);
63-
}
142+
}
143+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
144+
#ifdef __CUDA
145+
cublasErrcheck(cublasZscal(cublas_handle, n, (double2*)alpha, (double2*)X, incx));
146+
#endif
147+
}
64148
}
65149

66150

@@ -70,6 +154,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
70154
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
71155
return sdot_(&n, X, &incX, Y, &incY);
72156
}
157+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
158+
#ifdef __CUDA
159+
float result = 0.0;
160+
cublasErrcheck(cublasSdot(cublas_handle, n, X, incx, Y, incy, &result));
161+
return result;
162+
#endif
163+
}
73164
return sdot_(&n, X, &incX, Y, &incY);
74165
}
75166

@@ -78,6 +169,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
78169
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
79170
return ddot_(&n, X, &incX, Y, &incY);
80171
}
172+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
173+
#ifdef __CUDA
174+
double result = 0.0;
175+
cublasErrcheck(cublasDdot(cublas_handle, n, X, incx, Y, incy, &result));
176+
return result;
177+
#endif
178+
}
81179
return ddot_(&n, X, &incX, Y, &incY);
82180
}
83181

@@ -91,13 +189,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
91189
&alpha, b, &ldb, a, &lda,
92190
&beta, c, &ldc);
93191
}
94-
#ifdef __DSP
192+
#ifdef __DSP
95193
else if (device_type == base_device::AbacusDevice_t::DspDevice){
96194
sgemm_mth_(&transb, &transa, &n, &m, &k,
97195
&alpha, b, &ldb, a, &lda,
98196
&beta, c, &ldc, GlobalV::MY_RANK);
99197
}
100-
#endif
198+
#endif
199+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
200+
#ifdef __CUDA
201+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
202+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
203+
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
204+
#endif
205+
}
101206
}
102207

103208
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -109,13 +214,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
109214
&alpha, b, &ldb, a, &lda,
110215
&beta, c, &ldc);
111216
}
112-
#ifdef __DSP
217+
#ifdef __DSP
113218
else if (device_type == base_device::AbacusDevice_t::DspDevice){
114219
dgemm_mth_(&transb, &transa, &n, &m, &k,
115220
&alpha, b, &ldb, a, &lda,
116221
&beta, c, &ldc, GlobalV::MY_RANK);
117222
}
118-
#endif
223+
#endif
224+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
225+
#ifdef __CUDA
226+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
227+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
228+
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
229+
#endif
230+
}
119231
}
120232

121233
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -127,13 +239,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
127239
&alpha, b, &ldb, a, &lda,
128240
&beta, c, &ldc);
129241
}
130-
#ifdef __DSP
242+
#ifdef __DSP
131243
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
132244
cgemm_mth_(&transb, &transa, &n, &m, &k,
133245
&alpha, b, &ldb, a, &lda,
134246
&beta, c, &ldc, GlobalV::MY_RANK);
135247
}
136-
#endif
248+
#endif
249+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
250+
#ifdef __CUDA
251+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
252+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
253+
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
254+
#endif
255+
}
137256
}
138257

139258
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -145,13 +264,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
145264
&alpha, b, &ldb, a, &lda,
146265
&beta, c, &ldc);
147266
}
148-
#ifdef __DSP
267+
#ifdef __DSP
149268
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
150269
zgemm_mth_(&transb, &transa, &n, &m, &k,
151270
&alpha, b, &ldb, a, &lda,
152271
&beta, c, &ldc, GlobalV::MY_RANK);
153272
}
154-
#endif
273+
#endif
274+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
275+
#ifdef __CUDA
276+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
277+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
278+
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
279+
#endif
280+
}
155281
}
156282

157283
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -160,7 +286,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
160286
{
161287
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162288
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
163-
}
289+
}
290+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
291+
#ifdef __CUDA
292+
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
293+
cublasErrcheck(cublasSgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
294+
#endif
295+
}
164296
}
165297

166298
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -169,7 +301,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
169301
{
170302
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
171303
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
172-
}
304+
}
305+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
306+
#ifdef __CUDA
307+
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
308+
cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
309+
#endif
310+
}
173311
}
174312

175313
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -178,7 +316,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
178316
{
179317
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
180318
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
181-
}
319+
}
320+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
321+
#ifdef __CUDA
322+
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
323+
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
324+
#endif
325+
}
182326
}
183327

184328
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -187,7 +331,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
187331
{
188332
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
189333
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
190-
}
334+
}
335+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
336+
#ifdef __CUDA
337+
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
338+
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
339+
#endif
340+
}
191341
}
192342

193343

0 commit comments

Comments
 (0)