@@ -52,11 +52,11 @@ namespace BlasUtils{
5252
5353 cublasSideMode_t judge_side (const char & trans)
5454 {
55- if (trans == " L " )
55+ if (trans == ' L ' )
5656 {
5757 return CUBLAS_SIDE_LEFT;
5858 }
59- else if (trans == " R " )
59+ else if (trans == ' R ' )
6060 {
6161 return CUBLAS_SIDE_RIGHT;
6262 }
@@ -65,15 +65,15 @@ namespace BlasUtils{
6565
6666 cublasFillMode_t judge_fill (const char & trans)
6767 {
68- if (trans == " F " )
68+ if (trans == ' F ' )
6969 {
7070 return CUBLAS_FILL_MODE_FULL;
7171 }
72- else if (trans == " U " )
72+ else if (trans == ' U ' )
7373 {
7474 return CUBLAS_FILL_MODE_UPPER;
7575 }
76- else if (trans == " D " )
76+ else if (trans == ' D ' )
7777 {
7878 return CUBLAS_FILL_MODE_LOWER;
7979 }
@@ -430,9 +430,9 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
430430 }
431431 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
432432#ifdef __CUDA
433- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
434- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
435- cublasErrcheck (cublasSsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*) &alpha, (double2*) a, lda, (double2*) b, ldb, (double2*) &beta, (double2*) c, ldc));
433+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
434+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
435+ cublasErrcheck (cublasSsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
436436#endif
437437 }
438438}
@@ -448,8 +448,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
448448 }
449449 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
450450#ifdef __CUDA
451- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
452- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
451+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
452+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
453453 cublasErrcheck (cublasDsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
454454#endif
455455 }
@@ -466,8 +466,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
466466 }
467467 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
468468#ifdef __CUDA
469- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
470- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
469+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
470+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
471471 cublasErrcheck (cublasCsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
472472#endif
473473 }
@@ -484,8 +484,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
484484 }
485485 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
486486#ifdef __CUDA
487- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
488- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
487+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
488+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
489489 cublasErrcheck (cublasZsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
490490#endif
491491 }
@@ -502,8 +502,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
502502 }
503503 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
504504#ifdef __CUDA
505- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
506- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
505+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
506+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
507507 cublasErrcheck (cublasChemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
508508#endif
509509 }
@@ -520,8 +520,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
520520 }
521521 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
522522#ifdef __CUDA
523- cublasSideMode_t sideMode = BlasUtils::judge_side (transa );
524- cublasOperation_t fillMode = BlasUtils::judge_fill (transb );
523+ cublasSideMode_t sideMode = BlasUtils::judge_side (side );
524+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo );
525525 cublasErrcheck (cublasZhemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
526526#endif
527527 }
@@ -536,7 +536,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
536536 }
537537 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538538#ifdef __CUDA
539- cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa , " gemv_op" );
539+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans , " gemv_op" );
540540 cublasErrcheck (cublasSgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
541541#endif
542542 }
@@ -551,7 +551,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
551551 }
552552 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
553553#ifdef __CUDA
554- cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa , " gemv_op" );
554+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans , " gemv_op" );
555555 cublasErrcheck (cublasDgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
556556#endif
557557 }
@@ -566,8 +566,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
566566 }
567567 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
568568#ifdef __CUDA
569- cublasOperation_t cutransA = BlasUtils::judge_trans (true , transa, " gemv_op" );
570- cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy));
569+ cuFloatComplex alpha_cu = make_cuFloatComplex (alpha.real (), alpha.imag ());
570+ cuFloatComplex beta_cu = make_cuFloatComplex (beta.real (), beta.imag ());
571+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
572+ cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
571573#endif
572574 }
573575}
@@ -581,8 +583,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
581583 }
582584 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
583585#ifdef __CUDA
584- cublasOperation_t cutransA = BlasUtils::judge_trans (true , transa, " gemv_op" );
585- cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy));
586+ cuDoubleComplex alpha_cu = make_cuDoubleComplex (alpha.real (), alpha.imag ());
587+ cuDoubleComplex beta_cu = make_cuDoubleComplex (beta.real (), beta.imag ());
588+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
589+ cublasErrcheck (cublasZgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
586590#endif
587591 }
588592}
@@ -627,8 +631,8 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
627631 }
628632 else if (device_type == base_device::AbacusDevice_t::GpuDevice){
629633#ifdef __CUDA
630- std:: complex < double > result = 0.0 ;
631- cublasErrcheck (cublasDznrm2 (BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*) &result));
634+ double result = 0.0 ;
635+ cublasErrcheck (cublasDznrm2 (BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
632636 return result;
633637#endif
634638 }
0 commit comments