Skip to content

Commit 7da86cd

Browse files
authored
Feature: sDFT now supports GPU (#5423)
* make chebyshev support GPU
* make part of stoiter support GPU
* make scf support GPU
* add in GPU class implement
* fix GPU bug
* fix compile
* fix bug
* fix compile
* try fix pytest
1 parent 83efe85 commit 7da86cd

37 files changed

+1014
-355
lines changed

python/pyabacus/src/ModuleBase/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ list(APPEND pymodule_base
22
${PROJECT_SOURCE_DIR}/src/ModuleBase/py_base_math.cpp
33
${BASE_PATH}/kernels/math_op.cpp
44
${BASE_PATH}/module_device/memory_op.cpp
5+
${BASE_PATH}/module_device/device.cpp
56
)
67

78
pybind11_add_module(_base_pack MODULE ${pymodule_base})

python/pyabacus/src/ModuleNAO/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ list(APPEND _naos
1717
${ABACUS_SOURCE_DIR}/module_base/kernels/math_op.cpp
1818
# ${ABACUS_SOURCE_DIR}/module_psi/kernels/psi_memory_op.cpp
1919
${ABACUS_SOURCE_DIR}/module_base/module_device/memory_op.cpp
20+
${ABACUS_SOURCE_DIR}/module_base/module_device/device.cpp
2021
)
2122
add_library(naopack SHARED
2223
${_naos}

source/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ list(APPEND device_srcs
3030
module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp
3131
module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp
3232
module_hamilt_pw/hamilt_pwdft/kernels/meta_op.cpp
33+
module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp
3334
module_basis/module_pw/kernels/pw_op.cpp
3435
module_hsolver/kernels/dngvd_op.cpp
3536
module_hsolver/kernels/math_kernel_op.cpp
@@ -55,6 +56,7 @@ if(USE_CUDA)
5556
module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu
5657
module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu
5758
module_hamilt_pw/hamilt_pwdft/kernels/cuda/meta_op.cu
59+
module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu
5860
module_basis/module_pw/kernels/cuda/pw_op.cu
5961
module_hsolver/kernels/cuda/dngvd_op.cu
6062
module_hsolver/kernels/cuda/math_kernel_op.cu
@@ -78,6 +80,7 @@ if(USE_ROCM)
7880
module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu
7981
module_hamilt_pw/hamilt_pwdft/kernels/rocm/ekinetic_op.hip.cu
8082
module_hamilt_pw/hamilt_pwdft/kernels/rocm/meta_op.hip.cu
83+
module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu
8184
module_basis/module_pw/kernels/rocm/pw_op.hip.cu
8285
module_hsolver/kernels/rocm/dngvd_op.hip.cu
8386
module_hsolver/kernels/rocm/math_kernel_op.hip.cu

source/Makefile.Objects

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ VPATH=./src_global:\
4949
./module_hamilt_pw/hamilt_stodft:\
5050
./module_hamilt_pw/hamilt_pwdft/operator_pw:\
5151
./module_hamilt_pw/hamilt_pwdft/kernels:\
52+
./module_hamilt_pw/hamilt_stodft/kernels:\
5253
./module_hamilt_lcao/module_hcontainer:\
5354
./module_hamilt_lcao/hamilt_lcaodft:\
5455
./module_hamilt_lcao/module_tddft:\
@@ -296,6 +297,7 @@ OBJS_HAMILT=hamilt_pw.o\
296297
operator_pw.o\
297298
ekinetic_pw.o\
298299
ekinetic_op.o\
300+
hpsi_norm_op.o\
299301
veff_pw.o\
300302
veff_op.o\
301303
nonlocal_pw.o\

source/module_base/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ add_library(
6161
${LIBM_SRC}
6262
)
6363

64+
target_link_libraries(base PUBLIC container)
65+
6466
add_subdirectory(module_container)
6567

6668
if(ENABLE_COVERAGE)

source/module_base/blas_connector.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
154154
#endif
155155
}
156156

157+
// Single-precision (float) overload of gemv.
// When device_type equals base_device::AbacusDevice_t::CpuDevice, calls sgemv_
// with every scalar argument passed by address and the array pointers passed through.
// NOTE(review): for any other device_type this body performs no operation, so Y is
// left untouched and no error is reported -- confirm whether a device path is expected.
void BlasConnector::gemv(const char trans, const int m, const int n,
158+
const float alpha, const float* A, const int lda, const float* X, const int incx,
159+
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
160+
{
161+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162+
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
163+
}
164+
}
165+
157166
void BlasConnector::gemv(const char trans, const int m, const int n,
158167
const double alpha, const double* A, const int lda, const double* X, const int incx,
159168
const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type)

source/module_base/blas_connector.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ extern "C"
4040
double dznrm2_( const int *n, const std::complex<double> *X, const int *incX );
4141

4242
// level 2: matrix-vector operations, O(n^2) data and O(n^2) work.
43+
void sgemv_(const char*const transa, const int*const m, const int*const n,
44+
const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx,
45+
const float*const beta, float*const y, const int*const incy);
4346
void dgemv_(const char*const transa, const int*const m, const int*const n,
4447
const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx,
4548
const double*const beta, double*const y, const int*const incy);
@@ -178,6 +181,11 @@ class BlasConnector
178181
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
179182
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
180183

184+
// Single-precision (float) gemv overload; its parameter list mirrors the double
// overload declared next, with float in place of double.
// device_type defaults to base_device::AbacusDevice_t::CpuDevice when omitted.
static
185+
void gemv(const char trans, const int m, const int n,
186+
const float alpha, const float* A, const int lda, const float* X, const int incx,
187+
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
188+
181189
static
182190
void gemv(const char trans, const int m, const int n,
183191
const double alpha, const double* A, const int lda, const double* X, const int incx,

0 commit comments

Comments
 (0)