Skip to content

Commit 583920a

Browse files
committed
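Fix trans comparison bug (compare char arguments against character literals instead of string literals; pass the actual side/uplo/trans parameters and use the matching cuBLAS types and routines)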
1 parent 0e41185 commit 583920a

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

source/module_base/blas_connector.cpp

Lines changed: 30 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -52,11 +52,11 @@ namespace BlasUtils{
5252

5353
cublasSideMode_t judge_side(const char& trans)
5454
{
55-
if (trans == "L")
55+
if (trans == 'L')
5656
{
5757
return CUBLAS_SIDE_LEFT;
5858
}
59-
else if (trans == "R")
59+
else if (trans == 'R')
6060
{
6161
return CUBLAS_SIDE_RIGHT;
6262
}
@@ -65,15 +65,15 @@ namespace BlasUtils{
6565

6666
cublasFillMode_t judge_fill(const char& trans)
6767
{
68-
if (trans == "F")
68+
if (trans == 'F')
6969
{
7070
return CUBLAS_FILL_MODE_FULL;
7171
}
72-
else if (trans == "U")
72+
else if (trans == 'U')
7373
{
7474
return CUBLAS_FILL_MODE_UPPER;
7575
}
76-
else if (trans == "D")
76+
else if (trans == 'D')
7777
{
7878
return CUBLAS_FILL_MODE_LOWER;
7979
}
@@ -430,9 +430,9 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
430430
}
431431
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
432432
#ifdef __CUDA
433-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
434-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
435-
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
433+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
434+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
435+
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
436436
#endif
437437
}
438438
}
@@ -448,8 +448,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
448448
}
449449
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
450450
#ifdef __CUDA
451-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
452-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
451+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
452+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
453453
cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
454454
#endif
455455
}
@@ -466,8 +466,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
466466
}
467467
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
468468
#ifdef __CUDA
469-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
470-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
469+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
470+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
471471
cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
472472
#endif
473473
}
@@ -484,8 +484,8 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
484484
}
485485
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
486486
#ifdef __CUDA
487-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
488-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
487+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
488+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
489489
cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
490490
#endif
491491
}
@@ -502,8 +502,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
502502
}
503503
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
504504
#ifdef __CUDA
505-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
506-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
505+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
506+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
507507
cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
508508
#endif
509509
}
@@ -520,8 +520,8 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
520520
}
521521
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
522522
#ifdef __CUDA
523-
cublasSideMode_t sideMode = BlasUtils::judge_side(transa);
524-
cublasOperation_t fillMode = BlasUtils::judge_fill(transb);
523+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
524+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
525525
cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
526526
#endif
527527
}
@@ -536,7 +536,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
536536
}
537537
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538538
#ifdef __CUDA
539-
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op");
539+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
540540
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
541541
#endif
542542
}
@@ -551,7 +551,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
551551
}
552552
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
553553
#ifdef __CUDA
554-
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemv_op");
554+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
555555
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
556556
#endif
557557
}
@@ -566,8 +566,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
566566
}
567567
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
568568
#ifdef __CUDA
569-
cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op");
570-
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incx, (float2*)&beta, (float2*)Y, incy));
569+
cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag());
570+
cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag());
571+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
572+
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
571573
#endif
572574
}
573575
}
@@ -581,8 +583,10 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
581583
}
582584
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
583585
#ifdef __CUDA
584-
cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemv_op");
585-
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incx, (double2*)&beta, (double2*)Y, incy));
586+
cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag());
587+
cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag());
588+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
589+
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
586590
#endif
587591
}
588592
}
@@ -627,8 +631,8 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
627631
}
628632
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
629633
#ifdef __CUDA
630-
std::complex<double> result = 0.0;
631-
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, (double2*)&result));
634+
double result = 0.0;
635+
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
632636
return result;
633637
#endif
634638
}

0 commit comments

Comments
 (0)