Skip to content

Commit 55461ee

Browse files
committed
initial commit
1 parent 8407ee9 commit 55461ee

File tree

1 file changed

+137
-4
lines changed

1 file changed

+137
-4
lines changed

source/module_base/blas_connector.cpp

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ namespace BlasUtils{
5050
return CUBLAS_OP_N;
5151
}
5252

53+
cublasSideMode_t judge_side(const char& trans)
54+
{
55+
if (trans == "L")
56+
{
57+
return CUBLAS_SIDE_LEFT;
58+
}
59+
else if (trans == "R")
60+
{
61+
return CUBLAS_SIDE_RIGHT;
62+
}
63+
return CUBLAS_SIDE_LEFT;
64+
}
65+
66+
cublasFillMode_t judge_fill(const char& trans)
67+
{
68+
if (trans == "F")
69+
{
70+
return CUBLAS_FILL_MODE_FULL;
71+
}
72+
else if (trans == "U")
73+
{
74+
return CUBLAS_FILL_MODE_UPPER;
75+
}
76+
else if (trans == "D")
77+
{
78+
return CUBLAS_FILL_MODE_LOWER;
79+
}
80+
return CUBLAS_FILL_MODE_FULL;
81+
}
82+
5383
} // namespace BlasUtils
5484

5585
#endif
@@ -184,6 +214,22 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
184214
return ddot_(&n, X, &incX, Y, &incY);
185215
}
186216

217+
double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
218+
{
219+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
220+
return ddot_(&n, X, &incX, Y, &incY);
221+
}
222+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
223+
#ifdef __CUDA
224+
double result = 0.0;
225+
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
226+
return result;
227+
#endif
228+
}
229+
return ddot_(&n, X, &incX, Y, &incY);
230+
}
231+
232+
187233
// C = a * A.? * B.? + b * C
188234
// Row-Major part
189235
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -398,6 +444,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
398444
&alpha, a, &lda, b, &ldb,
399445
&beta, c, &ldc);
400446
}
447+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
448+
#ifdef __CUDA
449+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
450+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
451+
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
452+
#endif
453+
}
401454
}
402455

403456
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -409,6 +462,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
409462
&alpha, a, &lda, b, &ldb,
410463
&beta, c, &ldc);
411464
}
465+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
466+
#ifdef __CUDA
467+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
468+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
469+
cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
470+
#endif
471+
}
412472
}
413473

414474
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -420,6 +480,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
420480
&alpha, a, &lda, b, &ldb,
421481
&beta, c, &ldc);
422482
}
483+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
484+
#ifdef __CUDA
485+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
486+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
487+
cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
488+
#endif
489+
}
423490
}
424491

425492
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -431,6 +498,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
431498
&alpha, a, &lda, b, &ldb,
432499
&beta, c, &ldc);
433500
}
501+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
502+
#ifdef __CUDA
503+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
504+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
505+
cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
506+
#endif
507+
}
434508
}
435509

436510
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
@@ -442,6 +516,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
442516
&alpha, a, &lda, b, &ldb,
443517
&beta, c, &ldc);
444518
}
519+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
520+
#ifdef __CUDA
521+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
522+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
523+
cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
524+
#endif
525+
}
445526
}
446527

447528
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
@@ -453,6 +534,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
453534
&alpha, a, &lda, b, &ldb,
454535
&beta, c, &ldc);
455536
}
537+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538+
#ifdef __CUDA
539+
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
540+
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
541+
cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
542+
#endif
543+
}
456544
}
457545

458546
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -461,7 +549,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
461549
{
462550
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
463551
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
464-
}
552+
}
553+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
554+
#ifdef __CUDA
555+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op");
556+
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
557+
#endif
558+
}
465559
}
466560

467561
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -470,7 +564,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
470564
{
471565
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
472566
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
473-
}
567+
}
568+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
569+
#ifdef __CUDA
570+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op");
571+
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
572+
#endif
573+
}
474574
}
475575

476576
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -479,7 +579,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
479579
{
480580
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
481581
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
482-
}
582+
}
583+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
584+
#ifdef __CUDA
585+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op");
586+
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy));
587+
#endif
588+
}
483589
}
484590

485591
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -488,7 +594,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
488594
{
489595
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
490596
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
491-
}
597+
}
598+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
599+
#ifdef __CUDA
600+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op");
601+
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy));
602+
#endif
603+
}
492604
}
493605

494606
// out = ||x||_2
@@ -497,6 +609,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
497609
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
498610
return snrm2_( &n, X, &incX );
499611
}
612+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
613+
#ifdef __CUDA
614+
float result = 0.0;
615+
cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
616+
return result;
617+
#endif
618+
}
500619
return snrm2_( &n, X, &incX );
501620
}
502621

@@ -506,6 +625,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
506625
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
507626
return dnrm2_( &n, X, &incX );
508627
}
628+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
629+
#ifdef __CUDA
630+
double result = 0.0;
631+
cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
632+
return result;
633+
#endif
634+
}
509635
return dnrm2_( &n, X, &incX );
510636
}
511637

@@ -515,6 +641,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
515641
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
516642
return dznrm2_( &n, X, &incX );
517643
}
644+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
645+
#ifdef __CUDA
646+
std::complex<double> result = 0.0;
647+
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*)&result));
648+
return result;
649+
#endif
650+
}
518651
return dznrm2_( &n, X, &incX );
519652
}
520653

0 commit comments

Comments
 (0)