1515
1616#include " cublas_v2.h"
1717
18+ namespace BlasUtils {
19+
1820static cublasHandle_t cublas_handle = nullptr ;
1921
2022void createGpuBlasHandle (){
@@ -30,7 +32,9 @@ void destoryBLAShandle(){
3032 }
3133}
3234
33- cublasOperation_t judge_trans_op (bool is_complex, const char & trans, const char * name)
35+ } // namespace BlasUtils
36+
37+ cublasOperation_t judge_trans (bool is_complex, const char & trans, const char * name)
3438{
3539 if (trans == ' N' )
3640 {
@@ -44,10 +48,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
4448 {
4549 return CUBLAS_OP_C;
4650 }
47- else
48- {
49- ModuleBase::WARNING_QUIT (name, std::string (" Unknown trans type " ) + trans + std::string (" !" ));
50- }
51+ return CUBLAS_OP_N;
5152}
5253
5354#endif
@@ -59,7 +60,7 @@ void BlasConnector::axpy( const int n, const float alpha, const float *X, const
5960 }
6061 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
6162#ifdef __CUDA
62- cublasErrcheck (cublasSaxpy (cublas_handle, n, &alpha, X, incX, Y, incY));
63+ cublasErrcheck (cublasSaxpy (BlasUtils:: cublas_handle, n, &alpha, X, incX, Y, incY));
6364#endif
6465 }
6566}
@@ -71,7 +72,7 @@ void BlasConnector::axpy( const int n, const double alpha, const double *X, cons
7172 }
7273 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
7374#ifdef __CUDA
74- cublasErrcheck (cublasDaxpy (cublas_handle, n, &alpha, X, incX, Y, incY));
75+ cublasErrcheck (cublasDaxpy (BlasUtils:: cublas_handle, n, &alpha, X, incX, Y, incY));
7576#endif
7677 }
7778}
@@ -83,7 +84,7 @@ void BlasConnector::axpy( const int n, const std::complex<float> alpha, const st
8384 }
8485 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
8586#ifdef __CUDA
86- cublasErrcheck (cublasCaxpy (cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
87+ cublasErrcheck (cublasCaxpy (BlasUtils:: cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
8788#endif
8889 }
8990}
@@ -95,7 +96,7 @@ void BlasConnector::axpy( const int n, const std::complex<double> alpha, const s
9596 }
9697 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
9798#ifdef __CUDA
98- cublasErrcheck (cublasZaxpy (cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
99+ cublasErrcheck (cublasZaxpy (BlasUtils:: cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
99100#endif
100101 }
101102}
@@ -109,7 +110,7 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
109110 }
110111 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
111112#ifdef __CUDA
112- cublasErrcheck (cublasSscal (cublas_handle, n, &alpha, X, incX));
113+ cublasErrcheck (cublasSscal (BlasUtils:: cublas_handle, n, &alpha, X, incX));
113114#endif
114115 }
115116}
@@ -121,7 +122,7 @@ void BlasConnector::scal( const int n, const double alpha, double *X, const int
121122 }
122123 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
123124#ifdef __CUDA
124- cublasErrcheck (cublasDscal (cublas_handle, n, &alpha, X, incX));
125+ cublasErrcheck (cublasDscal (BlasUtils:: cublas_handle, n, &alpha, X, incX));
125126#endif
126127 }
127128}
@@ -133,7 +134,7 @@ void BlasConnector::scal( const int n, const std::complex<float> alpha, std::com
133134 }
134135 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
135136#ifdef __CUDA
136- cublasErrcheck (cublasCscal (cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
137+ cublasErrcheck (cublasCscal (BlasUtils:: cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
137138#endif
138139 }
139140}
@@ -145,7 +146,7 @@ void BlasConnector::scal( const int n, const std::complex<double> alpha, std::co
145146 }
146147 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
147148#ifdef __CUDA
148- cublasErrcheck (cublasZscal (cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
149+ cublasErrcheck (cublasZscal (BlasUtils:: cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
149150#endif
150151 }
151152}
@@ -160,7 +161,7 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
160161 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
161162#ifdef __CUDA
162163 float result = 0.0 ;
163- cublasErrcheck (cublasSdot (cublas_handle, n, X, incX, Y, incY, &result));
164+ cublasErrcheck (cublasSdot (BlasUtils:: cublas_handle, n, X, incX, Y, incY, &result));
164165 return result;
165166#endif
166167 }
@@ -175,7 +176,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
175176 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
176177#ifdef __CUDA
177178 double result = 0.0 ;
178- cublasErrcheck (cublasDdot (cublas_handle, n, X, incX, Y, incY, &result));
179+ cublasErrcheck (cublasDdot (BlasUtils:: cublas_handle, n, X, incX, Y, incY, &result));
179180 return result;
180181#endif
181182 }
@@ -201,9 +202,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
201202#endif
202203 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
203204#ifdef __CUDA
204- cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
205- cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
206- cublasErrcheck (cublasSgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
205+ cublasOperation_t cutransA = judge_trans (false , transa, " gemm_op" );
206+ cublasOperation_t cutransB = judge_trans (false , transb, " gemm_op" );
207+ cublasErrcheck (cublasSgemm (BlasUtils:: cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
207208#endif
208209 }
209210}
@@ -226,9 +227,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
226227#endif
227228 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
228229#ifdef __CUDA
229- cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
230- cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
231- cublasErrcheck (cublasDgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
230+ cublasOperation_t cutransA = judge_trans (false , transa, " gemm_op" );
231+ cublasOperation_t cutransB = judge_trans (false , transb, " gemm_op" );
232+ cublasErrcheck (cublasDgemm (BlasUtils:: cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
232233#endif
233234 }
234235}
@@ -251,9 +252,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
251252#endif
252253 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
253254#ifdef __CUDA
254- cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
255- cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
256- cublasErrcheck (cublasCgemm (cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
255+ cublasOperation_t cutransA = judge_trans (false , transa, " gemm_op" );
256+ cublasOperation_t cutransB = judge_trans (false , transb, " gemm_op" );
257+ cublasErrcheck (cublasCgemm (BlasUtils:: cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
257258#endif
258259 }
259260}
@@ -276,9 +277,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
276277#endif
277278 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
278279#ifdef __CUDA
279- cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
280- cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
281- cublasErrcheck (cublasZgemm (cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
280+ cublasOperation_t cutransA = judge_trans (false , transa, " gemm_op" );
281+ cublasOperation_t cutransB = judge_trans (false , transb, " gemm_op" );
282+ cublasErrcheck (cublasZgemm (BlasUtils:: cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
282283#endif
283284 }
284285}
@@ -292,8 +293,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
292293 }
293294 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
294295#ifdef __CUDA
295- cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
296- cublasErrcheck (cublasSgemv (cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
296+ cublasOperation_t cutrans = judge_trans (false , trans, " gemv_op" );
297+ cublasErrcheck (cublasSgemv (BlasUtils:: cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
297298#endif
298299 }
299300}
@@ -307,8 +308,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
307308 }
308309 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
309310#ifdef __CUDA
310- cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
311- cublasErrcheck (cublasDgemv (cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
311+ cublasOperation_t cutrans = judge_trans (false , trans, " gemv_op" );
312+ cublasErrcheck (cublasDgemv (BlasUtils:: cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
312313#endif
313314 }
314315}
@@ -322,8 +323,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
322323 }
323324 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
324325#ifdef __CUDA
325- cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
326- cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incX, (float2*)&beta, (float2*)Y, incY));
326+ cublasOperation_t cutrans = judge_trans (false , trans, " gemv_op" );
327+ cublasErrcheck (cublasCgemv (BlasUtils:: cublas_handle, cutrans, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incX, (float2*)&beta, (float2*)Y, incY));
327328#endif
328329 }
329330}
@@ -337,8 +338,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
337338 }
338339 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
339340#ifdef __CUDA
340- cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
341- cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incX, (double2*)&beta, (double2*)Y, incY));
341+ cublasOperation_t cutrans = judge_trans (false , trans, " gemv_op" );
342+ cublasErrcheck (cublasZgemv (BlasUtils:: cublas_handle, cutrans, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incX, (double2*)&beta, (double2*)Y, incY));
342343#endif
343344 }
344345}
0 commit comments