@@ -50,6 +50,36 @@ namespace BlasUtils{
5050 return CUBLAS_OP_N;
5151 }
5252
53+ cublasSideMode_t judge_side (const char & trans)
54+ {
55+ if (trans == ' L' )
56+ {
57+ return CUBLAS_SIDE_LEFT;
58+ }
59+ else if (trans == ' R' )
60+ {
61+ return CUBLAS_SIDE_RIGHT;
62+ }
63+ return CUBLAS_SIDE_LEFT;
64+ }
65+
66+ cublasFillMode_t judge_fill (const char & trans)
67+ {
68+ if (trans == ' F' )
69+ {
70+ return CUBLAS_FILL_MODE_FULL;
71+ }
72+ else if (trans == ' U' )
73+ {
74+ return CUBLAS_FILL_MODE_UPPER;
75+ }
76+ else if (trans == ' D' )
77+ {
78+ return CUBLAS_FILL_MODE_LOWER;
79+ }
80+ return CUBLAS_FILL_MODE_FULL;
81+ }
82+
5383} // namespace BlasUtils
5484
5585#endif
@@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
398428 &alpha, a, &lda, b, &ldb,
399429 &beta, c, &ldc);
400430 }
431+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
432+ #ifdef __CUDA
433+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
434+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
435+ cublasErrcheck (cublasSsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
436+ #endif
437+ }
401438}
402439
403440void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
409446 &alpha, a, &lda, b, &ldb,
410447 &beta, c, &ldc);
411448 }
449+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
450+ #ifdef __CUDA
451+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
452+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
453+ cublasErrcheck (cublasDsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
454+ #endif
455+ }
412456}
413457
414458void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
420464 &alpha, a, &lda, b, &ldb,
421465 &beta, c, &ldc);
422466 }
467+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
468+ #ifdef __CUDA
469+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
470+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
471+ cublasErrcheck (cublasCsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
472+ #endif
473+ }
423474}
424475
425476void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
431482 &alpha, a, &lda, b, &ldb,
432483 &beta, c, &ldc);
433484 }
485+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
486+ #ifdef __CUDA
487+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
488+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
489+ cublasErrcheck (cublasZsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
490+ #endif
491+ }
434492}
435493
436494void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
442500 &alpha, a, &lda, b, &ldb,
443501 &beta, c, &ldc);
444502 }
503+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
504+ #ifdef __CUDA
505+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
506+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
507+ cublasErrcheck (cublasChemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
508+ #endif
509+ }
445510}
446511
447512void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
453518 &alpha, a, &lda, b, &ldb,
454519 &beta, c, &ldc);
455520 }
521+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
522+ #ifdef __CUDA
523+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
524+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
525+ cublasErrcheck (cublasZhemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
526+ #endif
527+ }
456528}
457529
458530void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
461533{
462534 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
463535 sgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
464- }
536+ }
537+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538+ #ifdef __CUDA
539+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans, " gemv_op" );
540+ cublasErrcheck (cublasSgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
541+ #endif
542+ }
465543}
466544
467545void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
470548{
471549 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
472550 dgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
473- }
551+ }
552+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
553+ #ifdef __CUDA
554+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans, " gemv_op" );
555+ cublasErrcheck (cublasDgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
556+ #endif
557+ }
474558}
475559
476560void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
479563{
480564 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
481565 cgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
482- }
566+ }
567+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
568+ #ifdef __CUDA
569+ cuFloatComplex alpha_cu = make_cuFloatComplex (alpha.real (), alpha.imag ());
570+ cuFloatComplex beta_cu = make_cuFloatComplex (beta.real (), beta.imag ());
571+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
572+ cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
573+ #endif
574+ }
483575}
484576
485577void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
488580{
489581 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
490582 zgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
491- }
583+ }
584+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
585+ #ifdef __CUDA
586+ cuDoubleComplex alpha_cu = make_cuDoubleComplex (alpha.real (), alpha.imag ());
587+ cuDoubleComplex beta_cu = make_cuDoubleComplex (beta.real (), beta.imag ());
588+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
589+ cublasErrcheck (cublasZgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
590+ #endif
591+ }
492592}
493593
494594// out = ||x||_2
@@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
497597 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
498598 return snrm2_ ( &n, X, &incX );
499599 }
600+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
601+ #ifdef __CUDA
602+ float result = 0.0 ;
603+ cublasErrcheck (cublasSnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
604+ return result;
605+ #endif
606+ }
500607 return snrm2_ ( &n, X, &incX );
501608}
502609
@@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
506613 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
507614 return dnrm2_ ( &n, X, &incX );
508615 }
616+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
617+ #ifdef __CUDA
618+ double result = 0.0 ;
619+ cublasErrcheck (cublasDnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
620+ return result;
621+ #endif
622+ }
509623 return dnrm2_ ( &n, X, &incX );
510624}
511625
@@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
515629 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
516630 return dznrm2_ ( &n, X, &incX );
517631 }
632+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
633+ #ifdef __CUDA
634+ double result = 0.0 ;
635+ cublasErrcheck (cublasDznrm2 (BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
636+ return result;
637+ #endif
638+ }
518639 return dznrm2_ ( &n, X, &incX );
519640}
520641
0 commit comments