55#include " module_base/global_variable.h"
66#endif
77
8+ #ifdef __CUDA
9+ #include < base/macros/macros.h>
10+ #include < cuda_runtime.h>
11+ #include < thrust/complex.h>
12+ #include < thrust/execution_policy.h>
13+ #include < thrust/inner_product.h>
14+ #include " module_base/tool_quit.h"
15+
16+ #include " cublas_v2.h"
17+
18+ namespace BlasUtils {
19+
20+ static cublasHandle_t cublas_handle = nullptr ;
21+
22+ void createGpuBlasHandle (){
23+ if (cublas_handle == nullptr ) {
24+ cublasErrcheck (cublasCreate (&cublas_handle));
25+ }
26+ }
27+
28+ void destoryBLAShandle (){
29+ if (cublas_handle != nullptr ) {
30+ cublasErrcheck (cublasDestroy (cublas_handle));
31+ cublas_handle = nullptr ;
32+ }
33+ }
34+
35+
36+ cublasOperation_t judge_trans (bool is_complex, const char & trans, const char * name)
37+ {
38+ if (trans == ' N' )
39+ {
40+ return CUBLAS_OP_N;
41+ }
42+ else if (trans == ' T' )
43+ {
44+ return CUBLAS_OP_T;
45+ }
46+ else if (is_complex && trans == ' C' )
47+ {
48+ return CUBLAS_OP_C;
49+ }
50+ return CUBLAS_OP_N;
51+ }
52+
53+ } // namespace BlasUtils
54+
55+ #endif
56+
857void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
958{
1059 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1160 saxpy_ (&n, &alpha, X, &incX, Y, &incY);
12- }
61+ }
62+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
63+ #ifdef __CUDA
64+ cublasErrcheck (cublasSaxpy (BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
65+ #endif
66+ }
1367}
1468
1569void BlasConnector::axpy ( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
1670{
1771 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1872 daxpy_ (&n, &alpha, X, &incX, Y, &incY);
19- }
73+ }
74+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
75+ #ifdef __CUDA
76+ cublasErrcheck (cublasDaxpy (BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
77+ #endif
78+ }
2079}
2180
2281void BlasConnector::axpy ( const int n, const std::complex <float > alpha, const std::complex <float > *X, const int incX, std::complex <float > *Y, const int incY, base_device::AbacusDevice_t device_type)
2382{
2483 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
2584 caxpy_ (&n, &alpha, X, &incX, Y, &incY);
26- }
85+ }
86+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
87+ #ifdef __CUDA
88+ cublasErrcheck (cublasCaxpy (BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
89+ #endif
90+ }
2791}
2892
2993void BlasConnector::axpy ( const int n, const std::complex <double > alpha, const std::complex <double > *X, const int incX, std::complex <double > *Y, const int incY, base_device::AbacusDevice_t device_type)
3094{
3195 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
3296 zaxpy_ (&n, &alpha, X, &incX, Y, &incY);
33- }
97+ }
98+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
99+ #ifdef __CUDA
100+ cublasErrcheck (cublasZaxpy (BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
101+ #endif
102+ }
34103}
35104
36105
@@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
39108{
40109 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
41110 sscal_ (&n, &alpha, X, &incX);
42- }
111+ }
112+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
113+ #ifdef __CUDA
114+ cublasErrcheck (cublasSscal (BlasUtils::cublas_handle, n, &alpha, X, incX));
115+ #endif
116+ }
43117}
44118
45119void BlasConnector::scal ( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
46120{
47121 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
48122 dscal_ (&n, &alpha, X, &incX);
49- }
123+ }
124+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
125+ #ifdef __CUDA
126+ cublasErrcheck (cublasDscal (BlasUtils::cublas_handle, n, &alpha, X, incX));
127+ #endif
128+ }
50129}
51130
52131void BlasConnector::scal ( const int n, const std::complex <float > alpha, std::complex <float > *X, const int incX, base_device::AbacusDevice_t device_type)
53132{
54133 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
55134 cscal_ (&n, &alpha, X, &incX);
56- }
135+ }
136+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
137+ #ifdef __CUDA
138+ cublasErrcheck (cublasCscal (BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
139+ #endif
140+ }
57141}
58142
59143void BlasConnector::scal ( const int n, const std::complex <double > alpha, std::complex <double > *X, const int incX, base_device::AbacusDevice_t device_type)
60144{
61145 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
62146 zscal_ (&n, &alpha, X, &incX);
63- }
147+ }
148+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
149+ #ifdef __CUDA
150+ cublasErrcheck (cublasZscal (BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
151+ #endif
152+ }
64153}
65154
66155
@@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
70159 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
71160 return sdot_ (&n, X, &incX, Y, &incY);
72161 }
162+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
163+ #ifdef __CUDA
164+ float result = 0.0 ;
165+ cublasErrcheck (cublasSdot (BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
166+ return result;
167+ #endif
168+ }
73169 return sdot_ (&n, X, &incX, Y, &incY);
74170}
75171
@@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
78174 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
79175 return ddot_ (&n, X, &incX, Y, &incY);
80176 }
177+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
178+ #ifdef __CUDA
179+ double result = 0.0 ;
180+ cublasErrcheck (cublasDdot (BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
181+ return result;
182+ #endif
183+ }
81184 return ddot_ (&n, X, &incX, Y, &incY);
82185}
83186
@@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
92195 &alpha, b, &ldb, a, &lda,
93196 &beta, c, &ldc);
94197 }
95- #ifdef __DSP
198+ #ifdef __DSP
96199 else if (device_type == base_device::AbacusDevice_t::DspDevice){
97200 sgemm_mth_ (&transb, &transa, &n, &m, &k,
98201 &alpha, b, &ldb, a, &lda,
99202 &beta, c, &ldc, GlobalV::MY_RANK);
100203 }
101- #endif
204+ #endif
205+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
206+ #ifdef __CUDA
207+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
208+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
209+ cublasErrcheck (cublasSgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
210+ #endif
211+ }
102212}
103213
104214void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
110220 &alpha, b, &ldb, a, &lda,
111221 &beta, c, &ldc);
112222 }
113- #ifdef __DSP
223+ #ifdef __DSP
114224 else if (device_type == base_device::AbacusDevice_t::DspDevice){
115225 dgemm_mth_ (&transb, &transa, &n, &m, &k,
116226 &alpha, b, &ldb, a, &lda,
117227 &beta, c, &ldc, GlobalV::MY_RANK);
118228 }
119- #endif
229+ #endif
230+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
231+ #ifdef __CUDA
232+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
233+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
234+ cublasErrcheck (cublasDgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
235+ #endif
236+ }
120237}
121238
122239void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
128245 &alpha, b, &ldb, a, &lda,
129246 &beta, c, &ldc);
130247 }
131- #ifdef __DSP
248+ #ifdef __DSP
132249 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
133250 cgemm_mth_ (&transb, &transa, &n, &m, &k,
134251 &alpha, b, &ldb, a, &lda,
135252 &beta, c, &ldc, GlobalV::MY_RANK);
136253 }
137- #endif
254+ #endif
255+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
256+ #ifdef __CUDA
257+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
258+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
259+ cublasErrcheck (cublasCgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
260+ #endif
261+ }
138262}
139263
140264void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
146270 &alpha, b, &ldb, a, &lda,
147271 &beta, c, &ldc);
148272 }
149- #ifdef __DSP
273+ #ifdef __DSP
150274 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
151275 zgemm_mth_ (&transb, &transa, &n, &m, &k,
152276 &alpha, b, &ldb, a, &lda,
153277 &beta, c, &ldc, GlobalV::MY_RANK);
154278 }
155- #endif
279+ #endif
280+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
281+ #ifdef __CUDA
282+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
283+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
284+ cublasErrcheck (cublasZgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
285+ #endif
286+ }
156287}
157288
158289// Col-Major part
@@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
165296 &alpha, a, &lda, b, &ldb,
166297 &beta, c, &ldc);
167298 }
168- #ifdef __DSP
299+ #ifdef __DSP
169300 else if (device_type == base_device::AbacusDevice_t::DspDevice){
170301 sgemm_mth_ (&transb, &transa, &m, &n, &k,
171302 &alpha, a, &lda, b, &ldb,
172303 &beta, c, &ldc, GlobalV::MY_RANK);
173304 }
174- #endif
305+ #endif
306+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
307+ #ifdef __CUDA
308+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
309+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
310+ cublasErrcheck (cublasSgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
311+ #endif
312+ }
175313}
176314
177315void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
183321 &alpha, a, &lda, b, &ldb,
184322 &beta, c, &ldc);
185323 }
186- #ifdef __DSP
324+ #ifdef __DSP
187325 else if (device_type == base_device::AbacusDevice_t::DspDevice){
188326 dgemm_mth_ (&transa, &transb, &m, &n, &k,
189327 &alpha, a, &lda, b, &ldb,
190328 &beta, c, &ldc, GlobalV::MY_RANK);
191329 }
192- #endif
330+ #endif
331+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
332+ #ifdef __CUDA
333+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
334+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
335+ cublasErrcheck (cublasDgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
336+ #endif
337+ }
193338}
194339
195340void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
201346 &alpha, a, &lda, b, &ldb,
202347 &beta, c, &ldc);
203348 }
204- #ifdef __DSP
349+ #ifdef __DSP
205350 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
206351 cgemm_mth_ (&transa, &transb, &m, &n, &k,
207352 &alpha, a, &lda, b, &ldb,
208353 &beta, c, &ldc, GlobalV::MY_RANK);
209354 }
210- #endif
355+ #endif
356+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
357+ #ifdef __CUDA
358+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
359+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
360+ cublasErrcheck (cublasCgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
361+ #endif
362+ }
211363}
212364
213365void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
219371 &alpha, a, &lda, b, &ldb,
220372 &beta, c, &ldc);
221373 }
222- #ifdef __DSP
374+ #ifdef __DSP
223375 else if (device_type == base_device::AbacusDevice_t::DspDevice) {
224376 zgemm_mth_ (&transa, &transb, &m, &n, &k,
225377 &alpha, a, &lda, b, &ldb,
226378 &beta, c, &ldc, GlobalV::MY_RANK);
227379 }
228- #endif
380+ #endif
381+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
382+ #ifdef __CUDA
383+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
384+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
385+ cublasErrcheck (cublasZgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
386+ #endif
387+ }
229388}
230389
231390// Symm and Hemm part. Only col-major is supported.
0 commit comments