Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ OBJS_MAIN=main.o\

OBJS_BASE=abfs-vector3_order.o\
assoc_laguerre.o\
blas_connector.o\
blas_connector_base.o\
blas_connector_vector.o\
blas_connector_matrix.o\
complexarray.o\
complexmatrix.o\
clebsch_gordan_coeff.o\
Expand Down
4 changes: 3 additions & 1 deletion source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ add_library(
base
OBJECT
assoc_laguerre.cpp
blas_connector.cpp
blas_connector_base.cpp
blas_connector_vector.cpp
blas_connector_matrix.cpp
clebsch_gordan_coeff.cpp
complexarray.cpp
complexmatrix.cpp
Expand Down
21 changes: 19 additions & 2 deletions source/module_base/blas_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,26 @@ class BlasConnector

#ifdef __CUDA

#include <cuda_runtime.h>
#include "cublas_v2.h"

// If you want to use cublas, you need these functions to create and destroy the cublas/hipblas handle.
// You also need to use these functions to translate the transpose parameter into cublas/hipblas datatype.

namespace BlasUtils{
void createGpuBlasHandle();
void destoryBLAShandle();

static cublasHandle_t cublas_handle = nullptr;

void createGpuBlasHandle(); // Create a cublas/hipblas handle.

void destoryBLAShandle(); // Destroy the cublas/hipblas handle. Do this when the software is about to end.

cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name); // Translate a normal transpose parameter to a cublas/hipblas type.

cublasSideMode_t judge_side(const char& trans); // Translate a normal side parameter to a cublas/hipblas type.

cublasFillMode_t judge_fill(const char& trans); // Translate a normal fill parameter to a cublas/hipblas type.

}

#endif
Expand Down
77 changes: 77 additions & 0 deletions source/module_base/blas_connector_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "blas_connector.h"
#include "macros.h"

#ifdef __CUDA
#include <base/macros/macros.h>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#include "module_base/kernels/math_kernel_op.h"
#include "module_base/module_device/memory_op.h"


namespace BlasUtils{

void createGpuBlasHandle(){
if (cublas_handle == nullptr) {
cublasErrcheck(cublasCreate(&cublas_handle));
}
}

void destoryBLAShandle(){
if (cublas_handle != nullptr) {
cublasErrcheck(cublasDestroy(cublas_handle));
cublas_handle = nullptr;
}
}


cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
if (trans == 'N')
{
return CUBLAS_OP_N;
}
else if(trans == 'T')
{
return CUBLAS_OP_T;
}
else if(is_complex && trans == 'C')
{
return CUBLAS_OP_C;
}
return CUBLAS_OP_N;
}

cublasSideMode_t judge_side(const char& trans)
{
if (trans == 'L')
{
return CUBLAS_SIDE_LEFT;
}
else if (trans == 'R')
{
return CUBLAS_SIDE_RIGHT;
}
return CUBLAS_SIDE_LEFT;
}

cublasFillMode_t judge_fill(const char& trans)
{
if (trans == 'F')
{
return CUBLAS_FILL_MODE_FULL;
}
else if (trans == 'U')
{
return CUBLAS_FILL_MODE_UPPER;
}
else if (trans == 'D')
{
return CUBLAS_FILL_MODE_LOWER;
}
return CUBLAS_FILL_MODE_FULL;
}

} // namespace BlasUtils

#endif
Loading
Loading