diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp
index 85ea4584e9..3bb91e2f01 100644
--- a/source/module_base/blas_connector.cpp
+++ b/source/module_base/blas_connector.cpp
@@ -5,32 +5,101 @@
#include "module_base/global_variable.h"
#endif
+#ifdef __CUDA
+#include
+#include
+#include
+#include
+#include
+#include "module_base/tool_quit.h"
+
+#include "cublas_v2.h"
+
+namespace BlasUtils{
+
+    // Shared cuBLAS handle used by the GPU branches of BlasConnector; nullptr until created.
+    static cublasHandle_t cublas_handle = nullptr;
+
+    // Create the shared cuBLAS handle once; later calls are no-ops.
+    void createGpuBlasHandle(){
+        if (cublas_handle == nullptr) {
+            cublasErrcheck(cublasCreate(&cublas_handle));
+        }
+    }
+
+    // Destroy the shared cuBLAS handle (if any) and reset it so it can be created again.
+    void destoryBLAShandle(){
+        if (cublas_handle != nullptr) {
+            cublasErrcheck(cublasDestroy(cublas_handle));
+            cublas_handle = nullptr;
+        }
+    }
+
+    // Map a BLAS transpose flag ('N'/'T'/'C', either case) to the cuBLAS enum.
+    // For real data 'C' means plain transpose, as in reference BLAS. Any other flag
+    // is reported through WARNING_QUIT (tagged with `name`) instead of silently becoming OP_N.
+    cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
+    {
+        switch (trans) {
+            case 'N': case 'n': return CUBLAS_OP_N;
+            case 'T': case 't': return CUBLAS_OP_T;
+            case 'C': case 'c': return is_complex ? CUBLAS_OP_C : CUBLAS_OP_T;
+            default: break;
+        }
+        ModuleBase::WARNING_QUIT(name != nullptr ? name : "judge_trans", std::string("invalid transpose flag: ") + trans);
+        return CUBLAS_OP_N; // only reached if WARNING_QUIT returns; keeps every path returning a value
+    }
+
+} // namespace BlasUtils
+
+#endif
+
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
saxpy_(&n, &alpha, X, &incX, Y, &incY);
-}
+ }
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path: single-precision Y := alpha*X + Y through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); // NOTE(review): presumably X/Y are device pointers and createGpuBlasHandle() has already run — confirm with callers
+#endif
+ }
}
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
daxpy_(&n, &alpha, X, &incX, Y, &incY);
-}
+ }
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path: double-precision Y := alpha*X + Y through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); // NOTE(review): presumably X/Y are device pointers and the shared handle has been created — confirm with callers
+#endif
+ }
}
void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
caxpy_(&n, &alpha, X, &incX, Y, &incY);
-}
+ }
+ else if (base_device::AbacusDevice_t::GpuDevice == device_type){ // GPU path: single-precision complex Y := alpha*X + Y through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, reinterpret_cast<const float2*>(&alpha), reinterpret_cast<const float2*>(X), incX, reinterpret_cast<float2*>(Y), incY)); // named casts reinterpret std::complex as float2 while keeping alpha and X const
+#endif
+ }
}
void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
-}
+ }
+ else if (base_device::AbacusDevice_t::GpuDevice == device_type){ // GPU path: double-precision complex Y := alpha*X + Y through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, reinterpret_cast<const double2*>(&alpha), reinterpret_cast<const double2*>(X), incX, reinterpret_cast<double2*>(Y), incY)); // named casts reinterpret std::complex as double2 while keeping alpha and X const
+#endif
+ }
}
@@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sscal_(&n, &alpha, X, &incX);
-}
+ }
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) { // GPU path: single-precision X := alpha*X through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); // NOTE(review): presumably X is a device pointer and the shared handle has been created — confirm with callers
+#endif
+ }
}
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dscal_(&n, &alpha, X, &incX);
-}
+ }
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) { // GPU path: double-precision X := alpha*X through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); // NOTE(review): presumably X is a device pointer and the shared handle has been created — confirm with callers
+#endif
+ }
}
void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cscal_(&n, &alpha, X, &incX);
-}
+ }
+ else if (base_device::AbacusDevice_t::GpuDevice == device_type) { // GPU path: single-precision complex X := alpha*X through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, reinterpret_cast<const float2*>(&alpha), reinterpret_cast<float2*>(X), incX)); // named casts reinterpret std::complex as float2 while keeping alpha const
+#endif
+ }
}
void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zscal_(&n, &alpha, X, &incX);
-}
+ }
+ else if (base_device::AbacusDevice_t::GpuDevice == device_type) { // GPU path: double-precision complex X := alpha*X through cuBLAS
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, reinterpret_cast<const double2*>(&alpha), reinterpret_cast<double2*>(X), incX)); // named casts reinterpret std::complex as double2 while keeping alpha const
+#endif
+ }
}
@@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
}
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path: single-precision dot product through cuBLAS
+#ifdef __CUDA // without __CUDA nothing is returned here and control falls through to the sdot_ call after this branch
+ float result = 0.0; // local that receives the scalar written by cublasSdot
+ cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); // NOTE(review): a host address for the result presumably relies on the handle using host pointer mode — confirm
+ return result;
+#endif
+ }
return sdot_(&n, X, &incX, Y, &incY);
}
@@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
}
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path: double-precision dot product through cuBLAS
+#ifdef __CUDA // without __CUDA nothing is returned here and control falls through to the ddot_ call after this branch
+ double result = 0.0; // local that receives the scalar written by cublasDdot
+ cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); // NOTE(review): a host address for the result presumably relies on the handle using host pointer mode — confirm
+ return result;
+#endif
+ }
return ddot_(&n, X, &incX, Y, &incY);
}
@@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the row-major single-precision gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); // real data: is_complex=false
+ cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
+ cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); // b is passed first, so its op (cutransB) must come first too — same transb,transa order as the DSP call above
+#endif
+ }
}
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the row-major double-precision gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); // real data: is_complex=false
+ cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
+ cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); // b is passed first, so its op (cutransB) must come first too — same transb,transa order as the DSP call above
+#endif
+ }
}
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the row-major single-precision complex gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); // complex data: is_complex=true so 'C' maps to CUBLAS_OP_C
+ cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op");
+ cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc)); // b is passed first, so its op (cutransB) must come first too — same transb,transa order as the DSP call above
+#endif
+ }
}
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the row-major double-precision complex gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); // complex data: is_complex=true so 'C' maps to CUBLAS_OP_C
+ cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op");
+ cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc)); // b is passed first, so its op (cutransB) must come first too — same transb,transa order as the DSP call above
+#endif
+ }
}
// Col-Major part
@@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the column-major single-precision gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); // real data: is_complex=false
+ cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
+ cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); // arguments are forwarded in the caller's order (a before b, m before n), no operand swap
+#endif
+ }
}
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the column-major double-precision gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); // real data: is_complex=false
+ cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
+ cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); // arguments are forwarded in the caller's order (a before b, m before n), no operand swap
+#endif
+ }
}
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the column-major single-precision complex gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); // complex data: is_complex=true so 'C' maps to CUBLAS_OP_C instead of falling back to OP_N
+ cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op");
+ cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); // arguments are forwarded in the caller's order, no operand swap
+#endif
+ }
}
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
- #ifdef __DSP
+#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
- #endif
+#endif
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){ // GPU path of the column-major double-precision complex gemm
+#ifdef __CUDA // without __CUDA this branch is empty, so a GpuDevice request does nothing
+ cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); // complex data: is_complex=true so 'C' maps to CUBLAS_OP_C instead of falling back to OP_N
+ cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op");
+ cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); // arguments are forwarded in the caller's order, no operand swap
+#endif
+ }
}
// Symm and Hemm part. Only col-major is supported.
diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt
index 09b77c7404..0c8fd53461 100644
--- a/source/module_base/test/CMakeLists.txt
+++ b/source/module_base/test/CMakeLists.txt
@@ -2,8 +2,8 @@ remove_definitions(-D__MPI)
install(DIRECTORY data DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
AddTest(
TARGET base_blas_connector
- LIBS parameter ${math_libs}
- SOURCES blas_connector_test.cpp ../blas_connector.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES blas_connector_test.cpp
)
AddTest(
TARGET base_atom_in
@@ -31,8 +31,8 @@ AddTest(
)
ADDTest(
TARGET base_global_function
- LIBS parameter ${math_libs}
- SOURCES global_function_test.cpp ../blas_connector.cpp ../global_function.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp
+ LIBS parameter ${math_libs}
+ SOURCES global_function_test.cpp ../global_function.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp
)
AddTest(
TARGET base_vector3
@@ -41,8 +41,8 @@ AddTest(
)
AddTest(
TARGET base_matrix3
- LIBS parameter ${math_libs}
- SOURCES matrix3_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES matrix3_test.cpp
)
AddTest(
TARGET base_intarray
@@ -56,8 +56,8 @@ AddTest(
)
AddTest(
TARGET base_matrix
- LIBS parameter ${math_libs}
- SOURCES matrix_test.cpp ../blas_connector.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES matrix_test.cpp
)
AddTest(
TARGET base_complexarray
@@ -66,8 +66,8 @@ AddTest(
)
AddTest(
TARGET base_complexmatrix
- LIBS parameter ${math_libs}
- SOURCES complexmatrix_test.cpp ../blas_connector.cpp ../complexmatrix.cpp ../matrix.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES complexmatrix_test.cpp
)
AddTest(
TARGET base_integral
@@ -81,10 +81,8 @@ AddTest(
)
AddTest(
TARGET base_ylmreal
- LIBS parameter ${math_libs} device
- SOURCES math_ylmreal_test.cpp ../blas_connector.cpp ../math_ylmreal.cpp ../complexmatrix.cpp ../global_variable.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h
- ../parallel_reduce.cpp ../parallel_global.cpp ../parallel_comm.cpp ../parallel_common.cpp
- ../memory.cpp ../libm/branred.cpp ../libm/sincos.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES math_ylmreal_test.cpp ../libm/branred.cpp ../libm/sincos.cpp
)
AddTest(
TARGET base_math_sphbes
@@ -93,13 +91,13 @@ AddTest(
)
AddTest(
TARGET base_mathzone
- LIBS parameter ${math_libs}
- SOURCES mathzone_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES mathzone_test.cpp
)
AddTest(
TARGET base_mathzone_add1
- LIBS parameter ${math_libs}
- SOURCES mathzone_add1_test.cpp ../blas_connector.cpp ../mathzone_add1.cpp ../math_sphbes.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES mathzone_add1_test.cpp
)
AddTest(
TARGET base_math_polyint
@@ -108,8 +106,8 @@ AddTest(
)
AddTest(
TARGET base_gram_schmidt_orth
- LIBS parameter ${math_libs}
- SOURCES gram_schmidt_orth_test.cpp ../blas_connector.cpp ../gram_schmidt_orth.h ../gram_schmidt_orth-inl.h ../global_function.h ../math_integral.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES gram_schmidt_orth_test.cpp
)
AddTest(
TARGET base_math_bspline
@@ -118,8 +116,8 @@ AddTest(
)
AddTest(
TARGET base_inverse_matrix
- LIBS parameter ${math_libs}
- SOURCES inverse_matrix_test.cpp ../blas_connector.cpp ../inverse_matrix.cpp ../complexmatrix.cpp ../matrix.cpp ../timer.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES inverse_matrix_test.cpp
)
AddTest(
TARGET base_mymath
@@ -134,26 +132,26 @@ AddTest(
AddTest(
TARGET base_math_chebyshev
- LIBS parameter ${math_libs} device container
- SOURCES math_chebyshev_test.cpp ../blas_connector.cpp ../math_chebyshev.cpp ../tool_quit.cpp ../global_variable.cpp ../timer.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../parallel_reduce.cpp
+ LIBS parameter ${math_libs} base device container
+ SOURCES math_chebyshev_test.cpp
)
AddTest(
TARGET base_lapack_connector
- LIBS parameter ${math_libs}
- SOURCES lapack_connector_test.cpp ../blas_connector.cpp ../lapack_connector.h
+ LIBS parameter ${math_libs} base device
+ SOURCES lapack_connector_test.cpp
)
AddTest(
TARGET base_opt_CG
- LIBS parameter ${math_libs}
- SOURCES opt_CG_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES opt_CG_test.cpp opt_test_tools.cpp
)
AddTest(
TARGET base_opt_TN
- LIBS parameter ${math_libs}
- SOURCES opt_TN_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp
+ LIBS parameter ${math_libs} base device
+ SOURCES opt_TN_test.cpp opt_test_tools.cpp
)
AddTest(
@@ -194,28 +192,26 @@ AddTest(
AddTest(
TARGET spherical_bessel_transformer
- SOURCES spherical_bessel_transformer_test.cpp ../blas_connector.cpp ../spherical_bessel_transformer.cpp ../math_sphbes.cpp ../math_integral.cpp ../timer.cpp
- LIBS parameter ${math_libs}
+ SOURCES spherical_bessel_transformer_test.cpp
+ LIBS parameter ${math_libs} base device
)
AddTest(
TARGET cubic_spline
- SOURCES cubic_spline_test.cpp ../blas_connector.cpp ../cubic_spline.cpp
- LIBS parameter ${math_libs}
+ SOURCES cubic_spline_test.cpp
+ LIBS parameter ${math_libs} base device
)
AddTest(
TARGET clebsch_gordan_coeff_test
- SOURCES clebsch_gordan_coeff_test.cpp ../blas_connector.cpp ../clebsch_gordan_coeff.cpp ../intarray.cpp ../realarray.cpp ../complexmatrix.cpp ../matrix.cpp ../timer.cpp
- ../math_ylmreal.cpp ../global_variable.cpp ../ylm.cpp ../timer.cpp ../vector3.h ../parallel_reduce.cpp ../parallel_global.cpp ../parallel_comm.cpp ../parallel_common.cpp
- ../memory.cpp ../libm/branred.cpp ../libm/sincos.cpp ../inverse_matrix.cpp ../lapack_connector.h
- LIBS parameter ${math_libs} device
+ SOURCES clebsch_gordan_coeff_test.cpp
+ LIBS parameter ${math_libs} base device
)
AddTest(
TARGET assoc_laguerre_test
- SOURCES assoc_laguerre_test.cpp ../blas_connector.cpp ../assoc_laguerre.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
- LIBS parameter ${math_libs}
+ SOURCES assoc_laguerre_test.cpp
+ LIBS parameter ${math_libs} base device
)
AddTest(
diff --git a/source/module_base/test/clebsch_gordan_coeff_test.cpp b/source/module_base/test/clebsch_gordan_coeff_test.cpp
index 16efa091b5..888249765f 100644
--- a/source/module_base/test/clebsch_gordan_coeff_test.cpp
+++ b/source/module_base/test/clebsch_gordan_coeff_test.cpp
@@ -16,18 +16,6 @@
* - functions: gen_rndm_r and compute_ap
*/
-namespace ModuleBase
-{
-void WARNING_QUIT(const std::string& file, const std::string& description)
-{
- return;
-}
-void WARNING(const std::string& file, const std::string& description)
-{
- return;
-}
-} // namespace ModuleBase
-
TEST(ClebschGordanTest, ClebschGordanExit)
{
int lmaxkb = -2;
diff --git a/source/module_base/test/complexmatrix_test.cpp b/source/module_base/test/complexmatrix_test.cpp
index 0adc52363a..da11fafcfd 100644
--- a/source/module_base/test/complexmatrix_test.cpp
+++ b/source/module_base/test/complexmatrix_test.cpp
@@ -38,12 +38,6 @@
*
*/
-//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37.
-namespace ModuleBase
-{
- void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);}
-}
-
inline void EXPECT_COMPLEX_EQ(const std::complex& a,const std::complex& b)
{
EXPECT_DOUBLE_EQ(a.real(),b.real());
diff --git a/source/module_base/test/global_function_test.cpp b/source/module_base/test/global_function_test.cpp
index 013396d6b1..05d4d70877 100644
--- a/source/module_base/test/global_function_test.cpp
+++ b/source/module_base/test/global_function_test.cpp
@@ -4,7 +4,6 @@
#include "module_parameter/parameter.h"
#undef private
#include "../vector3.h"
-#include "../blas_connector.h"
#include "../tool_quit.h"
#include
#include
@@ -692,6 +691,9 @@ TEST_F(GlobalFunctionTest,MemAvailable)
TEST_F(GlobalFunctionTest,BlockHere)
{
+#ifdef __MPI
+#undef __MPI
+#endif
std::string output2;
std::string block_in="111";
GlobalV::MY_RANK=1;
@@ -706,6 +708,9 @@ TEST_F(GlobalFunctionTest,BlockHere)
TEST_F(GlobalFunctionTest,BlockHere2)
{
+#ifdef __MPI
+#undef __MPI
+#endif
std::string output2;
std::string block_in="111";
GlobalV::MY_RANK=0;
@@ -724,6 +729,9 @@ TEST_F(GlobalFunctionTest,BlockHere2)
TEST_F(GlobalFunctionTest,BlockHere3)
{
+#ifdef __MPI
+#undef __MPI
+#endif
std::string output2;
std::string block_in="111";
GlobalV::MY_RANK=0;
diff --git a/source/module_base/test/inverse_matrix_test.cpp b/source/module_base/test/inverse_matrix_test.cpp
index a871f906cd..b88e556af1 100644
--- a/source/module_base/test/inverse_matrix_test.cpp
+++ b/source/module_base/test/inverse_matrix_test.cpp
@@ -16,12 +16,6 @@
* - computes the inverse of a dim*dim real matrix
*/
-//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37.
-namespace ModuleBase
-{
- void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);}
-}
-
TEST(InverseMatrixComplexTest, InverseMatrixComplex)
{
int dim = 10;
diff --git a/source/module_base/test/math_chebyshev_test.cpp b/source/module_base/test/math_chebyshev_test.cpp
index 125dbdaeaa..a7ea215266 100644
--- a/source/module_base/test/math_chebyshev_test.cpp
+++ b/source/module_base/test/math_chebyshev_test.cpp
@@ -336,6 +336,8 @@ TEST_F(MathChebyshevTest, tracepolyA)
TEST_F(MathChebyshevTest, checkconverge)
{
+#ifdef __MPI
+#undef __MPI
const int norder = 100;
p_chetest = new ModuleBase::Chebyshev(norder);
auto fun_sigma_y
@@ -377,6 +379,8 @@ TEST_F(MathChebyshevTest, checkconverge)
delete[] v;
delete p_chetest;
+#define __MPI
+#endif
}
TEST_F(MathChebyshevTest, recurs)
diff --git a/source/module_base/test/math_ylmreal_test.cpp b/source/module_base/test/math_ylmreal_test.cpp
index c973d8cd28..891c948f7e 100644
--- a/source/module_base/test/math_ylmreal_test.cpp
+++ b/source/module_base/test/math_ylmreal_test.cpp
@@ -36,16 +36,6 @@
*
*/
-
-
-//mock functions of WARNING_QUIT and WARNING
-namespace ModuleBase
-{
- void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);}
- void WARNING(const std::string &file,const std::string &description) {return ;}
-}
-
-
class YlmRealTest : public testing::Test
{
protected:
diff --git a/source/module_base/test/opt_CG_test.cpp b/source/module_base/test/opt_CG_test.cpp
index 4b324c7cbb..b8abeb5760 100644
--- a/source/module_base/test/opt_CG_test.cpp
+++ b/source/module_base/test/opt_CG_test.cpp
@@ -1,3 +1,6 @@
+#ifdef __MPI
+#undef __MPI
+#endif
#include "gtest/gtest.h"
#include "../opt_CG.h"
#include "../opt_DCsrch.h"
@@ -18,10 +21,10 @@ class CG_test : public testing::Test
double residual = 10.;
double tol = 1e-5;
int final_iter = 0;
- char *task = NULL;
- double *Ap = NULL;
- double *p = NULL;
- double *x = NULL;
+ char *task = nullptr;
+ double *Ap = nullptr;
+ double *p = nullptr;
+ double *x = nullptr;
void SetUp()
{
@@ -65,7 +68,8 @@ class CG_test : public testing::Test
tools.le.get_Ap(tools.le.A, p, Ap);
int ifPD = 0;
step = cg.step_length(Ap, p, ifPD);
- for (int i = 0; i < 3; ++i) x[i] += step * p[i];
+ for (int i = 0; i < 3; ++i) { x[i] += step * p[i];
+}
residual = cg.get_residual();
}
}
@@ -102,14 +106,16 @@ class CG_test : public testing::Test
{
tools.dfuncdx(x, gradient, func_label);
residual = 0;
- for (int i = 0; i<3 ;++i) residual += gradient[i] * gradient[i];
+ for (int i = 0; i<3 ;++i) { residual += gradient[i] * gradient[i];
+}
if (residual < tol)
{
final_iter = iter;
break;
}
cg.next_direct(gradient, cg_label, p);
- for (int i = 0; i < 3; ++i) temp_x[i] = x[i];
+ for (int i = 0; i < 3; ++i) { temp_x[i] = x[i];
+}
task[0] = 'S'; task[1] = 'T'; task[2] = 'A'; task[3] = 'R'; task[4] = 'T';
while (true)
{
@@ -118,7 +124,8 @@ class CG_test : public testing::Test
ds.dcSrch(f, g, step, task);
if (task[0] == 'F' && task[1] == 'G')
{
- for (int j = 0; j < 3; ++j) temp_x[j] = x[j] + step * p[j];
+ for (int j = 0; j < 3; ++j) { temp_x[j] = x[j] + step * p[j];
+}
continue;
}
else if (task[0] == 'C' && task[1] == 'O')
@@ -134,7 +141,8 @@ class CG_test : public testing::Test
break;
}
}
- for (int i = 0; i < 3; ++i) x[i] += step * p[i];
+ for (int i = 0; i < 3; ++i) { x[i] += step * p[i];
+}
}
delete[] temp_x;
delete[] gradient;
@@ -143,51 +151,71 @@ class CG_test : public testing::Test
TEST_F(CG_test, Stand_Solve_LinearEq)
{
+#ifdef __MPI // NOTE(review): this patch undefines __MPI at the top of this file, so this guard looks always false and the body below would be compiled out — confirm the test still runs
+#undef __MPI
CG_Solve_LinearEq();
EXPECT_NEAR(x[0], 0.5, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], 1.6429086563584579739e-18, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 1.5, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 4);
ASSERT_EQ(cg.get_iter(), 4);
+#define __MPI // re-defines __MPI (as an empty macro) for the rest of the file when the guard was taken
+#endif
}
TEST_F(CG_test, PR_Solve_LinearEq)
{
+#ifdef __MPI // NOTE(review): this patch undefines __MPI at the top of this file, so this guard looks always false and the body below would be compiled out — confirm the test still runs
+#undef __MPI
Solve(1, 0);
EXPECT_NEAR(x[0], 0.50000000000003430589, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], -3.4028335704761047964e-14, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 1.5000000000000166533, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 3);
ASSERT_EQ(cg.get_iter(), 3);
+#define __MPI // re-defines __MPI (as an empty macro) for the rest of the file when the guard was taken
+#endif
}
TEST_F(CG_test, HZ_Solve_LinearEq)
{
+#ifdef __MPI // NOTE(review): this patch undefines __MPI at the top of this file, so this guard looks always false and the body below would be compiled out — confirm the test still runs
+#undef __MPI
Solve(2, 0);
EXPECT_NEAR(x[0], 0.49999999999999944489, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], -9.4368957093138305936e-16, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 1.5000000000000011102, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 3);
ASSERT_EQ(cg.get_iter(), 3);
+#define __MPI // re-defines __MPI (as an empty macro) for the rest of the file when the guard was taken
+#endif
}
TEST_F(CG_test, PR_Min_Func)
{
+#ifdef __MPI // NOTE(review): this patch undefines __MPI at the top of this file, so this guard looks always false and the body below would be compiled out — confirm the test still runs
+#undef __MPI
Solve(1, 1);
EXPECT_NEAR(x[0], 4.0006805979150792396, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], 2.0713759992720870429, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 9.2871067233169171118, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 18);
ASSERT_EQ(cg.get_iter(), 18);
+#define __MPI // re-defines __MPI (as an empty macro) for the rest of the file when the guard was taken
+#endif
}
TEST_F(CG_test, HZ_Min_Func)
{
+#ifdef __MPI // NOTE(review): this patch undefines __MPI at the top of this file, so this guard looks always false and the body below would be compiled out — confirm the test still runs
+#undef __MPI
Solve(2, 1);
EXPECT_NEAR(x[0], 4.0006825378033568086, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], 2.0691732100663737803, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 9.2780872787668311474, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 18);
ASSERT_EQ(cg.get_iter(), 18);
+#define __MPI // re-defines __MPI (as an empty macro) for the rest of the file when the guard was taken
+#endif
}
// g++ -std=c++11 ../opt_CG.cpp ../opt_DCsrch.cpp ./CG_test.cpp ./test_tools.cpp -lgtest -lpthread -lgtest_main -o test.exe
\ No newline at end of file
diff --git a/source/module_base/test/opt_TN_test.cpp b/source/module_base/test/opt_TN_test.cpp
index db523b53e9..1fc5b7f2d6 100644
--- a/source/module_base/test/opt_TN_test.cpp
+++ b/source/module_base/test/opt_TN_test.cpp
@@ -17,9 +17,9 @@ class TN_test : public testing::Test
double tol = 1e-5;
int final_iter = 0;
int flag = 0;
- char *task = NULL;
- double *p = NULL;
- double *x = NULL;
+ char *task = nullptr;
+ double *p = nullptr;
+ double *x = nullptr;
void SetUp()
{
@@ -61,7 +61,8 @@ class TN_test : public testing::Test
{
tools.dfuncdx(x, gradient, func_label);
residual = 0;
- for (int i = 0; i<3 ;++i) residual += gradient[i] * gradient[i];
+ for (int i = 0; i<3 ;++i) { residual += gradient[i] * gradient[i];
+}
if (residual < tol)
{
final_iter = iter;
@@ -75,7 +76,8 @@ class TN_test : public testing::Test
{
tn.next_direct(x, gradient, flag, p, &(tools.mf), &ModuleESolver::ESolver_OF::dfuncdx);
}
- for (int i = 0; i < 3; ++i) temp_x[i] = x[i];
+ for (int i = 0; i < 3; ++i) { temp_x[i] = x[i];
+}
task[0] = 'S'; task[1] = 'T'; task[2] = 'A'; task[3] = 'R'; task[4] = 'T';
while (true)
{
@@ -84,7 +86,8 @@ class TN_test : public testing::Test
ds.dcSrch(f, g, step, task);
if (task[0] == 'F' && task[1] == 'G')
{
- for (int j = 0; j < 3; ++j) temp_x[j] = x[j] + step * p[j];
+ for (int j = 0; j < 3; ++j) { temp_x[j] = x[j] + step * p[j];
+}
continue;
}
else if (task[0] == 'C' && task[1] == 'O')
@@ -100,7 +103,8 @@ class TN_test : public testing::Test
break;
}
}
- for (int i = 0; i < 3; ++i) x[i] += step * p[i];
+ for (int i = 0; i < 3; ++i) { x[i] += step * p[i];
+}
}
delete[] temp_x;
delete[] gradient;
@@ -110,20 +114,28 @@ class TN_test : public testing::Test
TEST_F(TN_test, TN_Solve_LinearEq)
{
+#ifdef __MPI // body is compiled only when __MPI is defined here; NOTE(review): the test CMakeLists in this patch calls remove_definitions(-D__MPI), so this may compile to an empty test — confirm
+#undef __MPI
Solve(0);
EXPECT_NEAR(x[0], 0.50000000000003430589, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], -3.4028335704761047964e-14, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 1.5000000000000166533, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 1);
ASSERT_EQ(tn.get_iter(), 1);
+#define __MPI // re-defines __MPI (as an empty macro) so the next guarded test sees it again
+#endif
}
TEST_F(TN_test, TN_Min_Func)
{
+#ifdef __MPI // body is compiled only when __MPI is defined here; NOTE(review): the test CMakeLists in this patch calls remove_definitions(-D__MPI), so this may compile to an empty test — confirm
+#undef __MPI
Solve(1);
EXPECT_NEAR(x[0], 4.0049968540891525137, DOUBLETHRESHOLD);
EXPECT_NEAR(x[1], 2.1208751163987624722, DOUBLETHRESHOLD);
EXPECT_NEAR(x[2], 9.4951527720891863993, DOUBLETHRESHOLD);
ASSERT_EQ(final_iter, 6);
ASSERT_EQ(tn.get_iter(), 6);
+#define __MPI // re-defines __MPI (as an empty macro) for any code that follows in this file
+#endif
}
\ No newline at end of file
diff --git a/source/module_base/test/opt_test_tools.cpp b/source/module_base/test/opt_test_tools.cpp
index 1c90b79bca..71e136b3ef 100644
--- a/source/module_base/test/opt_test_tools.cpp
+++ b/source/module_base/test/opt_test_tools.cpp
@@ -1,3 +1,6 @@
+#ifdef __MPI
+#undef __MPI
+#endif
#include "./opt_test_tools.h"
#include