Skip to content

Commit 7a0cbe9

Browse files
authored
[Refactor] Separate blas_connector.cpp into three files (#6178)
* INITIAL COMMIT * MODIFY UNIT TESTS * REMOVE BLAS_CONNECTOR.CPP * FIX MAKEFILE
1 parent 0dc5e8b commit 7a0cbe9

File tree

13 files changed

+590
-537
lines changed

13 files changed

+590
-537
lines changed

source/Makefile.Objects

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ OBJS_MAIN=main.o\
126126

127127
OBJS_BASE=abfs-vector3_order.o\
128128
assoc_laguerre.o\
129-
blas_connector.o\
129+
blas_connector_base.o\
130+
blas_connector_vector.o\
131+
blas_connector_matrix.o\
130132
complexarray.o\
131133
complexmatrix.o\
132134
clebsch_gordan_coeff.o\

source/module_base/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ add_library(
1010
base
1111
OBJECT
1212
assoc_laguerre.cpp
13-
blas_connector.cpp
13+
blas_connector_base.cpp
14+
blas_connector_vector.cpp
15+
blas_connector_matrix.cpp
1416
clebsch_gordan_coeff.cpp
1517
complexarray.cpp
1618
complexmatrix.cpp

source/module_base/blas_connector.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,26 @@ class BlasConnector
368368

369369
#ifdef __CUDA
370370

371+
#include <cuda_runtime.h>
372+
#include "cublas_v2.h"
373+
374+
// If you want to use cublas, you need these functions to create and destroy the cublas/hipblas handle.
375+
// You also need to use these functions to translate the transpose parameter into cublas/hipblas datatype.
376+
371377
namespace BlasUtils{
372-
void createGpuBlasHandle();
373-
void destoryBLAShandle();
378+
379+
static cublasHandle_t cublas_handle = nullptr;
380+
381+
void createGpuBlasHandle(); // Create a cublas/hipblas handle.
382+
383+
void destoryBLAShandle(); // Destroy the cublas/hipblas handle. Do this when the software is about to end.
384+
385+
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name); // Translate a normal transpose parameter to a cublas/hipblas type.
386+
387+
cublasSideMode_t judge_side(const char& trans); // Translate a normal side parameter to a cublas/hipblas type.
388+
389+
cublasFillMode_t judge_fill(const char& trans); // Translate a normal fill parameter to a cublas/hipblas type.
390+
374391
}
375392

376393
#endif
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "blas_connector.h"
2+
#include "macros.h"
3+
4+
#ifdef __CUDA
5+
#include <base/macros/macros.h>
6+
#include <cuda_runtime.h>
7+
#include "cublas_v2.h"
8+
#include "module_base/kernels/math_kernel_op.h"
9+
#include "module_base/module_device/memory_op.h"
10+
11+
12+
namespace BlasUtils{
13+
14+
void createGpuBlasHandle(){
15+
if (cublas_handle == nullptr) {
16+
cublasErrcheck(cublasCreate(&cublas_handle));
17+
}
18+
}
19+
20+
void destoryBLAShandle(){
21+
if (cublas_handle != nullptr) {
22+
cublasErrcheck(cublasDestroy(cublas_handle));
23+
cublas_handle = nullptr;
24+
}
25+
}
26+
27+
28+
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
29+
{
30+
if (trans == 'N')
31+
{
32+
return CUBLAS_OP_N;
33+
}
34+
else if(trans == 'T')
35+
{
36+
return CUBLAS_OP_T;
37+
}
38+
else if(is_complex && trans == 'C')
39+
{
40+
return CUBLAS_OP_C;
41+
}
42+
return CUBLAS_OP_N;
43+
}
44+
45+
cublasSideMode_t judge_side(const char& trans)
46+
{
47+
if (trans == 'L')
48+
{
49+
return CUBLAS_SIDE_LEFT;
50+
}
51+
else if (trans == 'R')
52+
{
53+
return CUBLAS_SIDE_RIGHT;
54+
}
55+
return CUBLAS_SIDE_LEFT;
56+
}
57+
58+
cublasFillMode_t judge_fill(const char& trans)
59+
{
60+
if (trans == 'F')
61+
{
62+
return CUBLAS_FILL_MODE_FULL;
63+
}
64+
else if (trans == 'U')
65+
{
66+
return CUBLAS_FILL_MODE_UPPER;
67+
}
68+
else if (trans == 'D')
69+
{
70+
return CUBLAS_FILL_MODE_LOWER;
71+
}
72+
return CUBLAS_FILL_MODE_FULL;
73+
}
74+
75+
} // namespace BlasUtils
76+
77+
#endif

0 commit comments

Comments
 (0)