Skip to content

Commit 2c9b933

Browse files
committed
INITIAL COMMIT
1 parent 80ab641 commit 2c9b933

File tree

5 files changed

+1150
-3
lines changed

5 files changed

+1150
-3
lines changed

source/module_base/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ add_library(
1010
base
1111
OBJECT
1212
assoc_laguerre.cpp
13-
blas_connector.cpp
13+
#blas_connector.cpp
14+
blas_connector_base.cpp
15+
blas_connector_vector.cpp
16+
blas_connector_matrix.cpp
1417
clebsch_gordan_coeff.cpp
1518
complexarray.cpp
1619
complexmatrix.cpp

source/module_base/blas_connector.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,26 @@ class BlasConnector
368368

369369
#ifdef __CUDA
370370

371+
#include <cuda_runtime.h>
372+
#include "cublas_v2.h"
373+
374+
// If you want to use cublas, you need these functions to create and destroy the cublas/hipblas handle.
375+
// You also need these functions to translate the transpose, side and fill parameters into the corresponding cublas/hipblas types.
376+
371377
namespace BlasUtils{
372-
void createGpuBlasHandle();
373-
void destoryBLAShandle();
378+
379+
static cublasHandle_t cublas_handle = nullptr;
380+
381+
void createGpuBlasHandle(); // Create a cublas/hipblas handle.
382+
383+
void destoryBLAShandle(); // Destroy the cublas/hipblas handle. Do this when the software is about to end.
384+
385+
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name); // Translate a normal transpose parameter to a cublas/hipblas type.
386+
387+
cublasSideMode_t judge_side(const char& trans); // Translate a normal side parameter to a cublas/hipblas type.
388+
389+
cublasFillMode_t judge_fill(const char& trans); // Translate a normal fill parameter to a cublas/hipblas type.
390+
374391
}
375392

376393
#endif
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include "blas_connector.h"
2+
#include "macros.h"
3+
4+
#ifdef __CUDA
5+
#include <base/macros/macros.h>
6+
#include <cuda_runtime.h>
7+
#include "cublas_v2.h"
8+
#include "module_base/kernels/math_kernel_op.h"
9+
#include "module_base/module_device/memory_op.h"
10+
11+
12+
namespace BlasUtils{
13+
14+
static cublasHandle_t cublas_handle = nullptr;
15+
16+
void createGpuBlasHandle(){
17+
if (cublas_handle == nullptr) {
18+
cublasErrcheck(cublasCreate(&cublas_handle));
19+
}
20+
}
21+
22+
void destoryBLAShandle(){
23+
if (cublas_handle != nullptr) {
24+
cublasErrcheck(cublasDestroy(cublas_handle));
25+
cublas_handle = nullptr;
26+
}
27+
}
28+
29+
30+
/**
 * Translate a BLAS-style transpose selector into a cublasOperation_t.
 *
 * @param is_complex  true when the operands are complex-valued.
 * @param trans       'N'/'n' (no transpose), 'T'/'t' (transpose) or
 *                    'C'/'c' (conjugate transpose).
 * @param name        label supplied by the caller; not used by this function.
 * @return CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C. For real operands 'C' yields
 *         CUBLAS_OP_T, because a conjugate transpose of real data is a plain
 *         transpose. Any unrecognised selector falls back to CUBLAS_OP_N.
 */
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
    (void)name; // kept for interface compatibility

    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        // Previously a real-valued 'C' fell through to CUBLAS_OP_N and the
        // requested transpose was silently dropped.
        return is_complex ? CUBLAS_OP_C : CUBLAS_OP_T;
    default:
        return CUBLAS_OP_N;
    }
}
46+
47+
/**
 * Translate a BLAS-style side selector into a cublasSideMode_t.
 *
 * @param trans  'L'/'l' for the left side, 'R'/'r' for the right side.
 * @return CUBLAS_SIDE_RIGHT for 'R'/'r'; CUBLAS_SIDE_LEFT for 'L'/'l' and for
 *         any unrecognised selector.
 */
cublasSideMode_t judge_side(const char& trans)
{
    switch (trans)
    {
    case 'R':
    case 'r': // lowercase used to be silently treated as the left side
        return CUBLAS_SIDE_RIGHT;
    case 'L':
    case 'l':
    default:
        return CUBLAS_SIDE_LEFT;
    }
}
59+
60+
/**
 * Translate a fill-mode selector into a cublasFillMode_t.
 *
 * @param trans  'F'/'f' for the full matrix, 'U'/'u' for the upper triangle,
 *               'D'/'d' or the standard BLAS 'L'/'l' for the lower triangle.
 * @return CUBLAS_FILL_MODE_UPPER, CUBLAS_FILL_MODE_LOWER or
 *         CUBLAS_FILL_MODE_FULL. Any unrecognised selector falls back to
 *         CUBLAS_FILL_MODE_FULL.
 */
cublasFillMode_t judge_fill(const char& trans)
{
    switch (trans)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'D':
    case 'd':
    case 'L': // standard BLAS spelling of "lower"; used to fall back to FULL
    case 'l':
        return CUBLAS_FILL_MODE_LOWER;
    case 'F':
    case 'f':
    default:
        return CUBLAS_FILL_MODE_FULL;
    }
}
76+
77+
} // namespace BlasUtils
78+
79+
#endif

0 commit comments

Comments
 (0)