Skip to content

Commit 0461869

Browse files
committed
FIXFIX
1 parent edce0c3 commit 0461869

File tree

8 files changed

+64
-63
lines changed

8 files changed

+64
-63
lines changed

source/module_base/blas_connector.cpp

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#include "cublas_v2.h"
1717

18+
namespace BlasUtils{
19+
1820
static cublasHandle_t cublas_handle = nullptr;
1921

2022
void createGpuBlasHandle(){
@@ -30,7 +32,9 @@ void destoryBLAShandle(){
3032
}
3133
}
3234

33-
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
35+
} // namespace BlasUtils
36+
37+
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
3438
{
3539
if (trans == 'N')
3640
{
@@ -44,10 +48,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
4448
{
4549
return CUBLAS_OP_C;
4650
}
47-
else
48-
{
49-
ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
50-
}
51+
return CUBLAS_OP_N;
5152
}
5253

5354
#endif
@@ -59,7 +60,7 @@ void BlasConnector::axpy( const int n, const float alpha, const float *X, const
5960
}
6061
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
6162
#ifdef __CUDA
62-
cublasErrcheck(cublasSaxpy(cublas_handle, n, &alpha, X, incX, Y, incY));
63+
cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
6364
#endif
6465
}
6566
}
@@ -71,7 +72,7 @@ void BlasConnector::axpy( const int n, const double alpha, const double *X, cons
7172
}
7273
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
7374
#ifdef __CUDA
74-
cublasErrcheck(cublasDaxpy(cublas_handle, n, &alpha, X, incX, Y, incY));
75+
cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
7576
#endif
7677
}
7778
}
@@ -83,7 +84,7 @@ void BlasConnector::axpy( const int n, const std::complex<float> alpha, const st
8384
}
8485
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
8586
#ifdef __CUDA
86-
cublasErrcheck(cublasCaxpy(cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
87+
cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
8788
#endif
8889
}
8990
}
@@ -95,7 +96,7 @@ void BlasConnector::axpy( const int n, const std::complex<double> alpha, const s
9596
}
9697
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
9798
#ifdef __CUDA
98-
cublasErrcheck(cublasZaxpy(cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
99+
cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
99100
#endif
100101
}
101102
}
@@ -109,7 +110,7 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
109110
}
110111
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
111112
#ifdef __CUDA
112-
cublasErrcheck(cublasSscal(cublas_handle, n, &alpha, X, incX));
113+
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
113114
#endif
114115
}
115116
}
@@ -121,7 +122,7 @@ void BlasConnector::scal( const int n, const double alpha, double *X, const int
121122
}
122123
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
123124
#ifdef __CUDA
124-
cublasErrcheck(cublasDscal(cublas_handle, n, &alpha, X, incX));
125+
cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
125126
#endif
126127
}
127128
}
@@ -133,7 +134,7 @@ void BlasConnector::scal( const int n, const std::complex<float> alpha, std::com
133134
}
134135
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
135136
#ifdef __CUDA
136-
cublasErrcheck(cublasCscal(cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
137+
cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
137138
#endif
138139
}
139140
}
@@ -145,7 +146,7 @@ void BlasConnector::scal( const int n, const std::complex<double> alpha, std::co
145146
}
146147
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
147148
#ifdef __CUDA
148-
cublasErrcheck(cublasZscal(cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
149+
cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
149150
#endif
150151
}
151152
}
@@ -160,7 +161,7 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
160161
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
161162
#ifdef __CUDA
162163
float result = 0.0;
163-
cublasErrcheck(cublasSdot(cublas_handle, n, X, incX, Y, incY, &result));
164+
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
164165
return result;
165166
#endif
166167
}
@@ -175,7 +176,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
175176
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
176177
#ifdef __CUDA
177178
double result = 0.0;
178-
cublasErrcheck(cublasDdot(cublas_handle, n, X, incX, Y, incY, &result));
179+
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
179180
return result;
180181
#endif
181182
}
@@ -201,9 +202,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
201202
#endif
202203
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
203204
#ifdef __CUDA
204-
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
205-
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
206-
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
205+
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
206+
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
207+
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
207208
#endif
208209
}
209210
}
@@ -226,9 +227,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
226227
#endif
227228
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
228229
#ifdef __CUDA
229-
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
230-
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
231-
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
230+
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
231+
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
232+
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
232233
#endif
233234
}
234235
}
@@ -251,9 +252,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
251252
#endif
252253
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
253254
#ifdef __CUDA
254-
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
255-
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
256-
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
255+
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
256+
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
257+
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
257258
#endif
258259
}
259260
}
@@ -276,9 +277,9 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
276277
#endif
277278
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
278279
#ifdef __CUDA
279-
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
280-
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
281-
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
280+
cublasOperation_t cutransA = judge_trans(false, transa, "gemm_op");
281+
cublasOperation_t cutransB = judge_trans(false, transb, "gemm_op");
282+
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
282283
#endif
283284
}
284285
}
@@ -292,8 +293,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
292293
}
293294
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
294295
#ifdef __CUDA
295-
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
296-
cublasErrcheck(cublasSgemv(cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
296+
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
297+
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
297298
#endif
298299
}
299300
}
@@ -307,8 +308,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
307308
}
308309
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
309310
#ifdef __CUDA
310-
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
311-
cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
311+
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
312+
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
312313
#endif
313314
}
314315
}
@@ -322,8 +323,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
322323
}
323324
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
324325
#ifdef __CUDA
325-
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
326-
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incX, (float2*)&beta, (float2*)Y, incY));
326+
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
327+
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutrans, m, n, (float2*)&alpha, (float2*)A, lda, (float2*)X, incX, (float2*)&beta, (float2*)Y, incY));
327328
#endif
328329
}
329330
}
@@ -337,8 +338,8 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
337338
}
338339
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
339340
#ifdef __CUDA
340-
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
341-
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incX, (double2*)&beta, (double2*)Y, incY));
341+
cublasOperation_t cutrans = judge_trans(false, trans, "gemv_op");
342+
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutrans, m, n, (double2*)&alpha, (double2*)A, lda, (double2*)X, incX, (double2*)&beta, (double2*)Y, incY));
342343
#endif
343344
}
344345
}

source/module_base/test/CMakeLists.txt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ remove_definitions(-D__MPI)
22
install(DIRECTORY data DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
33
AddTest(
44
TARGET base_blas_connector
5-
LIBS parameter ${math_libs}
5+
LIBS parameter ${math_libs} device
66
SOURCES blas_connector_test.cpp ../blas_connector.cpp
77
)
88
AddTest(
@@ -31,7 +31,7 @@ AddTest(
3131
)
3232
ADDTest(
3333
TARGET base_global_function
34-
LIBS parameter ${math_libs}
34+
LIBS parameter ${math_libs} base device
3535
SOURCES global_function_test.cpp ../blas_connector.cpp ../global_function.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp
3636
)
3737
AddTest(
@@ -41,7 +41,7 @@ AddTest(
4141
)
4242
AddTest(
4343
TARGET base_matrix3
44-
LIBS parameter ${math_libs}
44+
LIBS parameter ${math_libs} device
4545
SOURCES matrix3_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp
4646
)
4747
AddTest(
@@ -56,7 +56,7 @@ AddTest(
5656
)
5757
AddTest(
5858
TARGET base_matrix
59-
LIBS parameter ${math_libs}
59+
LIBS parameter ${math_libs} device
6060
SOURCES matrix_test.cpp ../blas_connector.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
6161
)
6262
AddTest(
@@ -66,7 +66,7 @@ AddTest(
6666
)
6767
AddTest(
6868
TARGET base_complexmatrix
69-
LIBS parameter ${math_libs}
69+
LIBS parameter ${math_libs} device
7070
SOURCES complexmatrix_test.cpp ../blas_connector.cpp ../complexmatrix.cpp ../matrix.cpp
7171
)
7272
AddTest(
@@ -93,12 +93,12 @@ AddTest(
9393
)
9494
AddTest(
9595
TARGET base_mathzone
96-
LIBS parameter ${math_libs}
96+
LIBS parameter ${math_libs} device
9797
SOURCES mathzone_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp
9898
)
9999
AddTest(
100100
TARGET base_mathzone_add1
101-
LIBS parameter ${math_libs}
101+
LIBS parameter ${math_libs} device
102102
SOURCES mathzone_add1_test.cpp ../blas_connector.cpp ../mathzone_add1.cpp ../math_sphbes.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
103103
)
104104
AddTest(
@@ -108,7 +108,7 @@ AddTest(
108108
)
109109
AddTest(
110110
TARGET base_gram_schmidt_orth
111-
LIBS parameter ${math_libs}
111+
LIBS parameter ${math_libs} device
112112
SOURCES gram_schmidt_orth_test.cpp ../blas_connector.cpp ../gram_schmidt_orth.h ../gram_schmidt_orth-inl.h ../global_function.h ../math_integral.cpp
113113
)
114114
AddTest(
@@ -118,7 +118,7 @@ AddTest(
118118
)
119119
AddTest(
120120
TARGET base_inverse_matrix
121-
LIBS parameter ${math_libs}
121+
LIBS parameter ${math_libs} device
122122
SOURCES inverse_matrix_test.cpp ../blas_connector.cpp ../inverse_matrix.cpp ../complexmatrix.cpp ../matrix.cpp ../timer.cpp
123123
)
124124
AddTest(
@@ -140,19 +140,19 @@ AddTest(
140140

141141
AddTest(
142142
TARGET base_lapack_connector
143-
LIBS parameter ${math_libs}
143+
LIBS parameter ${math_libs} device
144144
SOURCES lapack_connector_test.cpp ../blas_connector.cpp ../lapack_connector.h
145145
)
146146

147147
AddTest(
148148
TARGET base_opt_CG
149-
LIBS parameter ${math_libs}
149+
LIBS parameter ${math_libs} device
150150
SOURCES opt_CG_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp
151151
)
152152

153153
AddTest(
154154
TARGET base_opt_TN
155-
LIBS parameter ${math_libs}
155+
LIBS parameter ${math_libs} device
156156
SOURCES opt_TN_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp
157157
)
158158

@@ -195,13 +195,13 @@ AddTest(
195195
AddTest(
196196
TARGET spherical_bessel_transformer
197197
SOURCES spherical_bessel_transformer_test.cpp ../blas_connector.cpp ../spherical_bessel_transformer.cpp ../math_sphbes.cpp ../math_integral.cpp ../timer.cpp
198-
LIBS parameter ${math_libs}
198+
LIBS parameter ${math_libs} base device
199199
)
200200

201201
AddTest(
202202
TARGET cubic_spline
203203
SOURCES cubic_spline_test.cpp ../blas_connector.cpp ../cubic_spline.cpp
204-
LIBS parameter ${math_libs}
204+
LIBS parameter ${math_libs} device
205205
)
206206

207207
AddTest(
@@ -215,7 +215,7 @@ AddTest(
215215
AddTest(
216216
TARGET assoc_laguerre_test
217217
SOURCES assoc_laguerre_test.cpp ../blas_connector.cpp ../assoc_laguerre.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp
218-
LIBS parameter ${math_libs}
218+
LIBS parameter ${math_libs} device
219219
)
220220

221221
AddTest(

source/module_basis/module_pw/kernels/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ add_definitions(-D__NORMAL)
22

33
AddTest(
44
TARGET PW_Kernels_UTs
5-
LIBS parameter ${math_libs} psi device
5+
LIBS parameter ${math_libs} base psi device
66
SOURCES pw_op_test.cpp
77
../../../../module_base/tool_quit.cpp ../../../../module_base/global_variable.cpp
88
../../../../module_base/parallel_global.cpp ../../../../module_base/parallel_reduce.cpp

source/module_basis/module_pw/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
add_definitions(-D__NORMAL)
22
AddTest(
33
TARGET pw_test
4-
LIBS parameter ${math_libs} planewave device
4+
LIBS parameter ${math_libs} planewave base device
55
SOURCES ../../../module_base/matrix.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix3.cpp ../../../module_base/tool_quit.cpp
66
../../../module_base/mymath.cpp ../../../module_base/timer.cpp ../../../module_base/memory.cpp ../../../module_base/blas_connector.cpp
77
../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp

source/module_hamilt_general/module_xc/test/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ list(APPEND FFT_SRC ../../../module_basis/module_pw/module_fft/fft_rocm.cpp)
2929
endif()
3030
AddTest(
3131
TARGET XCTest_GRADCORR
32-
LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi device container
32+
LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi base device container
3333
SOURCES test_xc3.cpp ../xc_functional_gradcorr.cpp ../xc_functional.cpp
3434
../xc_functional_wrapper_xc.cpp ../xc_functional_wrapper_gcxc.cpp
3535
../xc_functional_libxc.cpp
@@ -63,7 +63,7 @@ AddTest(
6363

6464
AddTest(
6565
TARGET XCTest_VXC
66-
LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi device container
66+
LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi base device container
6767
SOURCES test_xc5.cpp ../xc_functional_gradcorr.cpp ../xc_functional.cpp
6868
../xc_functional_wrapper_xc.cpp ../xc_functional_wrapper_gcxc.cpp
6969
../xc_functional_libxc.cpp

source/module_hamilt_pw/hamilt_pwdft/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ remove_definitions(-D__EXX)
55

66
AddTest(
77
TARGET pwdft_soc
8-
LIBS parameter ${math_libs}
8+
LIBS parameter ${math_libs} base device
99
SOURCES soc_test.cpp ../soc.cpp
1010
../../../module_base/global_variable.cpp
1111
../../../module_base/global_function.cpp

0 commit comments

Comments
 (0)