@@ -50,6 +50,36 @@ namespace BlasUtils{
5050 return CUBLAS_OP_N;
5151 }
5252
53+ cublasSideMode_t judge_side (const char & trans)
54+ {
55+ if (trans == " L" )
56+ {
57+ return CUBLAS_SIDE_LEFT;
58+ }
59+ else if (trans == " R" )
60+ {
61+ return CUBLAS_SIDE_RIGHT;
62+ }
63+ return CUBLAS_SIDE_LEFT;
64+ }
65+
66+ cublasFillMode_t judge_fill (const char & trans)
67+ {
68+ if (trans == " F" )
69+ {
70+ return CUBLAS_FILL_MODE_FULL;
71+ }
72+ else if (trans == " U" )
73+ {
74+ return CUBLAS_FILL_MODE_UPPER;
75+ }
76+ else if (trans == " D" )
77+ {
78+ return CUBLAS_FILL_MODE_LOWER;
79+ }
80+ return CUBLAS_FILL_MODE_FULL;
81+ }
82+
5383} // namespace BlasUtils
5484
5585#endif
@@ -184,6 +214,22 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
184214 return ddot_ (&n, X, &incX, Y, &incY);
185215}
186216
217+ double BlasConnector::dot ( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
218+ {
219+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
220+ return ddot_ (&n, X, &incX, Y, &incY);
221+ }
222+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
223+ #ifdef __CUDA
224+ double result = 0.0 ;
225+ cublasErrcheck (cublasDdot (BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
226+ return result;
227+ #endif
228+ }
229+ return ddot_ (&n, X, &incX, Y, &incY);
230+ }
231+
232+
187233// C = a * A.? * B.? + b * C
188234// Row-Major part
189235void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -398,6 +444,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
398444 &alpha, a, &lda, b, &ldb,
399445 &beta, c, &ldc);
400446 }
447+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
448+ #ifdef __CUDA
449+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
450+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
451+ cublasErrcheck (cublasSsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
452+ #endif
453+ }
401454}
402455
403456void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -409,6 +462,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
409462 &alpha, a, &lda, b, &ldb,
410463 &beta, c, &ldc);
411464 }
465+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
466+ #ifdef __CUDA
467+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
468+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
469+ cublasErrcheck (cublasDsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
470+ #endif
471+ }
412472}
413473
414474void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -420,6 +480,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
420480 &alpha, a, &lda, b, &ldb,
421481 &beta, c, &ldc);
422482 }
483+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
484+ #ifdef __CUDA
485+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
486+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
487+ cublasErrcheck (cublasCsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
488+ #endif
489+ }
423490}
424491
425492void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -431,6 +498,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
431498 &alpha, a, &lda, b, &ldb,
432499 &beta, c, &ldc);
433500 }
501+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
502+ #ifdef __CUDA
503+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
504+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
505+ cublasErrcheck (cublasZsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
506+ #endif
507+ }
434508}
435509
436510void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -442,6 +516,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
442516 &alpha, a, &lda, b, &ldb,
443517 &beta, c, &ldc);
444518 }
519+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
520+ #ifdef __CUDA
521+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
522+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
523+ cublasErrcheck (cublasChemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
524+ #endif
525+ }
445526}
446527
447528void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -453,6 +534,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
453534 &alpha, a, &lda, b, &ldb,
454535 &beta, c, &ldc);
455536 }
537+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538+ #ifdef __CUDA
539+ cublasSideMode_t sideMode = BlasUtils::judge_side (transa);
540+ cublasOperation_t fillMode = BlasUtils::judge_fill (transb);
541+ cublasErrcheck (cublasZhemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
542+ #endif
543+ }
456544}
457545
458546void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -461,7 +549,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
461549{
462550 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
463551 sgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
464- }
552+ }
553+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
554+ #ifdef __CUDA
555+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemv_op" );
556+ cublasErrcheck (cublasSgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
557+ #endif
558+ }
465559}
466560
467561void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -470,7 +564,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
470564{
471565 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
472566 dgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
473- }
567+ }
568+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
569+ #ifdef __CUDA
570+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemv_op" );
571+ cublasErrcheck (cublasDgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
572+ #endif
573+ }
474574}
475575
476576void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -479,7 +579,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
479579{
480580 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
481581 cgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
482- }
582+ }
583+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
584+ #ifdef __CUDA
585+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , transa, " gemv_op" );
586+ cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy));
587+ #endif
588+ }
483589}
484590
485591void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -488,7 +594,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
488594{
489595 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
490596 zgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
491- }
597+ }
598+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
599+ #ifdef __CUDA
600+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , transa, " gemv_op" );
601+ cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy));
602+ #endif
603+ }
492604}
493605
494606// out = ||x||_2
@@ -497,6 +609,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
497609 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
498610 return snrm2_ ( &n, X, &incX );
499611 }
612+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
613+ #ifdef __CUDA
614+ float result = 0.0 ;
615+ cublasErrcheck (cublasSnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
616+ return result;
617+ #endif
618+ }
500619 return snrm2_ ( &n, X, &incX );
501620}
502621
@@ -506,6 +625,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
506625 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
507626 return dnrm2_ ( &n, X, &incX );
508627 }
628+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
629+ #ifdef __CUDA
630+ double result = 0.0 ;
631+ cublasErrcheck (cublasDnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
632+ return result;
633+ #endif
634+ }
509635 return dnrm2_ ( &n, X, &incX );
510636}
511637
@@ -515,6 +641,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
515641 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
516642 return dznrm2_ ( &n, X, &incX );
517643 }
644+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
645+ #ifdef __CUDA
646+ std::complex <double > result = 0.0 ;
647+ cublasErrcheck (cublasDznrm2 (BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*)&result));
648+ return result;
649+ #endif
650+ }
518651 return dznrm2_ ( &n, X, &incX );
519652}
520653
0 commit comments