55#include " module_base/global_variable.h"
66#endif
77
8+ #ifdef __CUDA
9+ #include < base/macros/macros.h>
10+ #include < cuda_runtime.h>
11+ #include < thrust/complex.h>
12+ #include < thrust/execution_policy.h>
13+ #include < thrust/inner_product.h>
14+
15+ static cublasHandle_t cublas_handle = nullptr ;
16+
17+ void createGpuBlasHandle (){
18+ if (cublas_handle == nullptr ) {
19+ cublasErrcheck (cublasCreate (&cublas_handle));
20+ }
21+ }
22+
23+ void destoryBLAShandle (){
24+ if (cublas_handle != nullptr ) {
25+ cublasErrcheck (cublasDestroy (cublas_handle));
26+ cublas_handle = nullptr ;
27+ }
28+ }
29+
30+ cublasOperation_t judge_trans_op (bool is_complex, const char & trans, const char * name)
31+ {
32+ if (trans == ' N' )
33+ {
34+ return CUBLAS_OP_N;
35+ }
36+ else if (trans == ' T' )
37+ {
38+ return CUBLAS_OP_T;
39+ }
40+ else if (is_complex && trans == ' C' )
41+ {
42+ return CUBLAS_OP_C;
43+ }
44+ else
45+ {
46+ ModuleBase::WARNING_QUIT (name, std::string (" Unknown trans type " ) + trans + std::string (" !" ));
47+ }
48+ }
49+
50+ #endif
51+
852void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
953{
1054 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1155 saxpy_ (&n, &alpha, X, &incX, Y, &incY);
12- }
56+ }
57+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
58+ #ifdef __CUDA
59+ cublasErrcheck (cublasSaxpy (cublas_handle, n, alpha, X, incX, Y, incY));
60+ #endif
61+ }
1362}
1463
1564void BlasConnector::axpy ( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
1665{
1766 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1867 daxpy_ (&n, &alpha, X, &incX, Y, &incY);
19- }
68+ }
69+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
70+ #ifdef __CUDA
71+ cublasErrcheck (cublasDaxpy (cublas_handle, n, alpha, X, incX, Y, incY));
72+ #endif
73+ }
2074}
2175
2276void BlasConnector::axpy ( const int n, const std::complex <float > alpha, const std::complex <float > *X, const int incX, std::complex <float > *Y, const int incY, base_device::AbacusDevice_t device_type)
2377{
2478 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
2579 caxpy_ (&n, &alpha, X, &incX, Y, &incY);
26- }
80+ }
81+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
82+ #ifdef __CUDA
83+ cublasErrcheck (cublasCaxpy (cublas_handle, n, alpha, X, incX, Y, incY));
84+ #endif
85+ }
2786}
2887
2988void BlasConnector::axpy ( const int n, const std::complex <double > alpha, const std::complex <double > *X, const int incX, std::complex <double > *Y, const int incY, base_device::AbacusDevice_t device_type)
3089{
3190 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
3291 zaxpy_ (&n, &alpha, X, &incX, Y, &incY);
33- }
92+ }
93+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
94+ #ifdef __CUDA
95+ cublasErrcheck (cublasZaxpy (cublas_handle, n, alpha, X, incX, Y, incY));
96+ #endif
97+ }
3498}
3599
36100
@@ -39,28 +103,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
39103{
40104 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
41105 sscal_ (&n, &alpha, X, &incX);
42- }
106+ }
107+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
108+ #ifdef __CUDA
109+ cublasErrcheck (cublasSscal (cublas_handle, n, (float2*)alpha, (float2*)X, incx));
110+ #endif
111+ }
43112}
44113
45114void BlasConnector::scal ( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
46115{
47116 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
48117 dscal_ (&n, &alpha, X, &incX);
49- }
118+ }
119+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
120+ #ifdef __CUDA
121+ cublasErrcheck (cublasDscal (cublas_handle, n, (double2*)alpha, (double2*)X, incx));
122+ #endif
123+ }
50124}
51125
52126void BlasConnector::scal ( const int n, const std::complex <float > alpha, std::complex <float > *X, const int incX, base_device::AbacusDevice_t device_type)
53127{
54128 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
55129 cscal_ (&n, &alpha, X, &incX);
56- }
130+ }
131+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
132+ #ifdef __CUDA
133+ cublasErrcheck (cublasCscal (cublas_handle, n, (float2*)alpha, (float2*)X, incx));
134+ #endif
135+ }
57136}
58137
59138void BlasConnector::scal ( const int n, const std::complex <double > alpha, std::complex <double > *X, const int incX, base_device::AbacusDevice_t device_type)
60139{
61140 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
62141 zscal_ (&n, &alpha, X, &incX);
63- }
142+ }
143+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
144+ #ifdef __CUDA
145+ cublasErrcheck (cublasZscal (cublas_handle, n, (double2*)alpha, (double2*)X, incx));
146+ #endif
147+ }
64148}
65149
66150
@@ -70,6 +154,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
70154 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
71155 return sdot_ (&n, X, &incX, Y, &incY);
72156 }
157+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
158+ #ifdef __CUDA
159+ float result = 0.0 ;
160+ cublasErrcheck (cublasSdot (cublas_handle, n, X, incx, Y, incy, &result));
161+ return result;
162+ #endif
163+ }
73164 return sdot_ (&n, X, &incX, Y, &incY);
74165}
75166
@@ -78,6 +169,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
78169 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
79170 return ddot_ (&n, X, &incX, Y, &incY);
80171 }
172+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
173+ #ifdef __CUDA
174+ double result = 0.0 ;
175+ cublasErrcheck (cublasDdot (cublas_handle, n, X, incx, Y, incy, &result));
176+ return result;
177+ #endif
178+ }
81179 return ddot_ (&n, X, &incX, Y, &incY);
82180}
83181
@@ -91,13 +189,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
91189 &alpha, b, &ldb, a, &lda,
92190 &beta, c, &ldc);
93191 }
94- #ifdef __DSP
192+ #ifdef __DSP
95193 else if (device_type == base_device::AbacusDevice_t::DspDevice){
96194 sgemm_mth_ (&transb, &transa, &n, &m, &k,
97195 &alpha, b, &ldb, a, &lda,
98196 &beta, c, &ldc, GlobalV::MY_RANK);
99197 }
100- #endif
198+ #endif
199+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
200+ #ifdef __CUDA
201+ cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
202+ cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
203+ cublasErrcheck (cublasSgemm (cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
204+ #endif
205+ }
101206}
102207
103208void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -109,13 +214,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
109214 &alpha, b, &ldb, a, &lda,
110215 &beta, c, &ldc);
111216 }
112- #ifdef __DSP
217+ #ifdef __DSP
113218 else if (device_type == base_device::AbacusDevice_t::DspDevice){
114219 dgemm_mth_ (&transb, &transa, &n, &m, &k,
115220 &alpha, b, &ldb, a, &lda,
116221 &beta, c, &ldc, GlobalV::MY_RANK);
117222 }
118- #endif
223+ #endif
224+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
225+ #ifdef __CUDA
226+ cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
227+ cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
228+ cublasErrcheck (cublasDgemm (cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
229+ #endif
230+ }
119231}
120232
121233void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -127,13 +239,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
127239 &alpha, b, &ldb, a, &lda,
128240 &beta, c, &ldc);
129241 }
130- #ifdef __DSP
242+ #ifdef __DSP
131243 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
132244 cgemm_mth_ (&transb, &transa, &n, &m, &k,
133245 &alpha, b, &ldb, a, &lda,
134246 &beta, c, &ldc, GlobalV::MY_RANK);
135247 }
136- #endif
248+ #endif
249+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
250+ #ifdef __CUDA
251+ cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
252+ cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
253+ cublasErrcheck (cublasCgemm (cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
254+ #endif
255+ }
137256}
138257
139258void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -145,13 +264,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
145264 &alpha, b, &ldb, a, &lda,
146265 &beta, c, &ldc);
147266 }
148- #ifdef __DSP
267+ #ifdef __DSP
149268 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
150269 zgemm_mth_ (&transb, &transa, &n, &m, &k,
151270 &alpha, b, &ldb, a, &lda,
152271 &beta, c, &ldc, GlobalV::MY_RANK);
153272 }
154- #endif
273+ #endif
274+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
275+ #ifdef __CUDA
276+ cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
277+ cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
278+ cublasErrcheck (cublasZgemm (cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
279+ #endif
280+ }
155281}
156282
157283void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -160,7 +286,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
160286{
161287 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162288 sgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
163- }
289+ }
290+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
291+ #ifdef __CUDA
292+ cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
293+ cublasErrcheck (cublasSgemv (cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
294+ #endif
295+ }
164296}
165297
166298void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -169,7 +301,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
169301{
170302 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
171303 dgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
172- }
304+ }
305+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
306+ #ifdef __CUDA
307+ cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
308+ cublasErrcheck (cublasDgemv (cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
309+ #endif
310+ }
173311}
174312
175313void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -178,7 +316,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
178316{
179317 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
180318 cgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
181- }
319+ }
320+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
321+ #ifdef __CUDA
322+ cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
323+ cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
324+ #endif
325+ }
182326}
183327
184328void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -187,7 +331,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
187331{
188332 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
189333 zgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
190- }
334+ }
335+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
336+ #ifdef __CUDA
337+ cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
338+ cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
339+ #endif
340+ }
191341}
192342
193343
0 commit comments