diff --git a/python/pyabacus/src/ModuleBase/CMakeLists.txt b/python/pyabacus/src/ModuleBase/CMakeLists.txt index f150cf1b5e..7ce5fb5e3b 100644 --- a/python/pyabacus/src/ModuleBase/CMakeLists.txt +++ b/python/pyabacus/src/ModuleBase/CMakeLists.txt @@ -2,6 +2,7 @@ list(APPEND pymodule_base ${PROJECT_SOURCE_DIR}/src/ModuleBase/py_base_math.cpp ${BASE_PATH}/kernels/math_op.cpp ${BASE_PATH}/module_device/memory_op.cpp + ${BASE_PATH}/module_device/device.cpp ) pybind11_add_module(_base_pack MODULE ${pymodule_base}) diff --git a/python/pyabacus/src/ModuleNAO/CMakeLists.txt b/python/pyabacus/src/ModuleNAO/CMakeLists.txt index 65e209ca40..53600a08f3 100644 --- a/python/pyabacus/src/ModuleNAO/CMakeLists.txt +++ b/python/pyabacus/src/ModuleNAO/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND _naos ${ABACUS_SOURCE_DIR}/module_base/kernels/math_op.cpp # ${ABACUS_SOURCE_DIR}/module_psi/kernels/psi_memory_op.cpp ${ABACUS_SOURCE_DIR}/module_base/module_device/memory_op.cpp + ${ABACUS_SOURCE_DIR}/module_base/module_device/device.cpp ) add_library(naopack SHARED ${_naos} diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index d53cd1e63b..d57fd16337 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -30,6 +30,7 @@ list(APPEND device_srcs module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/meta_op.cpp + module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp module_basis/module_pw/kernels/pw_op.cpp module_hsolver/kernels/dngvd_op.cpp module_hsolver/kernels/math_kernel_op.cpp @@ -55,6 +56,7 @@ if(USE_CUDA) module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu module_hamilt_pw/hamilt_pwdft/kernels/cuda/meta_op.cu + module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu module_basis/module_pw/kernels/cuda/pw_op.cu module_hsolver/kernels/cuda/dngvd_op.cu module_hsolver/kernels/cuda/math_kernel_op.cu @@ -78,6 +80,7 @@ 
if(USE_ROCM) module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/ekinetic_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/meta_op.hip.cu + module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu module_basis/module_pw/kernels/rocm/pw_op.hip.cu module_hsolver/kernels/rocm/dngvd_op.hip.cu module_hsolver/kernels/rocm/math_kernel_op.hip.cu diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 972a2e4269..dbd695e696 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -49,6 +49,7 @@ VPATH=./src_global:\ ./module_hamilt_pw/hamilt_stodft:\ ./module_hamilt_pw/hamilt_pwdft/operator_pw:\ ./module_hamilt_pw/hamilt_pwdft/kernels:\ +./module_hamilt_pw/hamilt_stodft/kernels:\ ./module_hamilt_lcao/module_hcontainer:\ ./module_hamilt_lcao/hamilt_lcaodft:\ ./module_hamilt_lcao/module_tddft:\ @@ -296,6 +297,7 @@ OBJS_HAMILT=hamilt_pw.o\ operator_pw.o\ ekinetic_pw.o\ ekinetic_op.o\ + hpsi_norm_op.o\ veff_pw.o\ veff_op.o\ nonlocal_pw.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index e11141208c..14deb8213a 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -61,6 +61,8 @@ add_library( ${LIBM_SRC} ) +target_link_libraries(base PUBLIC container) + add_subdirectory(module_container) if(ENABLE_COVERAGE) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 30b3b93d40..61ea4b390f 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -154,6 +154,15 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons #endif } +void BlasConnector::gemv(const char trans, const int m, const int n, + const float alpha, const float* A, const int lda, const float* X, const int incx, + const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == 
base_device::AbacusDevice_t::CpuDevice) { + sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); +} +} + void BlasConnector::gemv(const char trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* X, const int incx, const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type) diff --git a/source/module_base/blas_connector.h b/source/module_base/blas_connector.h index b819b6852e..e7b78bab06 100644 --- a/source/module_base/blas_connector.h +++ b/source/module_base/blas_connector.h @@ -40,6 +40,9 @@ extern "C" double dznrm2_( const int *n, const std::complex *X, const int *incX ); // level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work. + void sgemv_(const char*const transa, const int*const m, const int*const n, + const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx, + const float*const beta, float*const y, const int*const incy); void dgemv_(const char*const transa, const int*const m, const int*const n, const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx, const double*const beta, double*const y, const int*const incy); @@ -178,6 +181,11 @@ class BlasConnector const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static + void gemv(const char trans, const int m, const int n, + const float alpha, const float* A, const int lda, const float* X, const int incx, + const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static void gemv(const char trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* X, const int incx, diff --git 
a/source/module_base/math_chebyshev.cpp b/source/module_base/math_chebyshev.cpp index 76400ab630..9bfac7cac9 100644 --- a/source/module_base/math_chebyshev.cpp +++ b/source/module_base/math_chebyshev.cpp @@ -3,6 +3,7 @@ #include "blas_connector.h" #include "constants.h" #include "global_function.h" +#include "module_base/module_container/ATen/kernels/blas.h" #include "tool_quit.h" #include @@ -49,35 +50,55 @@ void FFTW::execute_fftw() // A number to control the number of grids in C_n integration #define EXTEND 16 -template -Chebyshev::Chebyshev(const int norder_in) : fftw(2 * EXTEND * norder_in) +template +Chebyshev::Chebyshev(const int norder_in) : fftw(2 * EXTEND * norder_in) { this->norder = norder_in; norder2 = 2 * norder * EXTEND; if (this->norder < 1) { - ModuleBase::WARNING_QUIT("Stochastic_Chebychev", "The Chebyshev expansion order should be at least 1!"); + ModuleBase::WARNING_QUIT("Chebyshev", "The Chebyshev expansion order should be at least 1!"); + } + coefr_cpu = new REAL[norder]; + coefc_cpu = new std::complex[norder]; + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + resmem_var_op()(this->ctx, this->coef_real, norder); + resmem_complex_op()(this->ctx, this->coef_complex, norder); + } + else + { + coef_real = coefr_cpu; + coef_complex = coefc_cpu; } polytrace = new REAL[norder]; - coef_real = new REAL[norder]; - coef_complex = new std::complex[norder]; // ndmin = ndmax = ndmax_in; - getcoef_complex = false; getcoef_real = false; } -template -Chebyshev::~Chebyshev() +template +Chebyshev::~Chebyshev() { delete[] polytrace; - delete[] coef_real; - delete[] coef_complex; + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + delmem_var_op()(this->ctx, this->coef_real); + delmem_complex_op()(this->ctx, this->coef_complex); + } + else + { + coef_real = nullptr; + coef_complex = nullptr; + } + + delete[] coefr_cpu; + delete[] coefc_cpu; } -template -void Chebyshev::getpolyval(const REAL x, REAL* polyval, 
const int N) +template +void Chebyshev::getpolyval(const REAL x, REAL* polyval, const int N) { polyval[0] = 1; polyval[1] = x; @@ -86,46 +107,57 @@ void Chebyshev::getpolyval(const REAL x, REAL* polyval, const int N) polyval[i] = 2 * x * polyval[i - 1] - polyval[i - 2]; } } -template -inline REAL Chebyshev::recurs(const REAL x, const REAL Tn, REAL const Tn_1) +template +inline REAL Chebyshev::recurs(const REAL x, const REAL Tn, REAL const Tn_1) { return 2 * x * Tn - Tn_1; } -template -REAL Chebyshev::ddot_real(const std::complex* psi_L, - const std::complex* psi_R, - const int N, - const int LDA, - const int m) +template +REAL Chebyshev::ddot_real(const std::complex* psi_L, + const std::complex* psi_R, + const int N, + const int LDA, + const int m) { REAL result = 0; + const base_device::DEVICE_CPU* cpu_ctx = {}; if (N == LDA || m == 1) { int dim2 = 2 * N * m; REAL *pL, *pR; pL = (REAL*)psi_L; pR = (REAL*)psi_R; - result = BlasConnector::dot(dim2, pL, 1, pR, 1); + REAL* dot_device = nullptr; + resmem_var_op()(this->ctx, dot_device, 1); + container::kernels::blas_dot()(dim2, pL, 1, pR, 1, dot_device); + syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result, dot_device, 1); + delmem_var_op()(this->ctx, dot_device); } else { REAL *pL, *pR; pL = (REAL*)psi_L; pR = (REAL*)psi_R; + REAL* dot_device = nullptr; + resmem_var_op()(this->ctx, dot_device, 1); for (int i = 0; i < m; ++i) { int dim2 = 2 * N; - result += BlasConnector::dot(dim2, pL, 1, pR, 1); + container::kernels::blas_dot()(dim2, pL, 1, pR, 1, dot_device); + REAL result_temp = 0; + syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result_temp, dot_device, 1); + result += result_temp; pL += 2 * LDA; pR += 2 * LDA; } + delmem_var_op()(this->ctx, dot_device); } return result; } -template -void Chebyshev::calcoef_real(std::function fun) +template +void Chebyshev::calcoef_real(std::function fun) { std::complex* pcoef = (std::complex*)this->fftw.ccoef; @@ -146,11 +178,11 @@ void Chebyshev::calcoef_real(std::function fun) REAL phi 
= i * ModuleBase::PI / norder2; if (i == 0) { - coef_real[i] = (cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3; + coefr_cpu[i] = (cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3; } else { - coef_real[i] = (cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3; + coefr_cpu[i] = (cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3; } } @@ -169,20 +201,25 @@ void Chebyshev::calcoef_real(std::function fun) { if (i == 0) { - coef_real[i] += real(pcoef[i]) / norder2 * 1 / 3; + coefr_cpu[i] += real(pcoef[i]) / norder2 * 1 / 3; } else { - coef_real[i] += real(pcoef[i]) / norder2 * 2 / 3; + coefr_cpu[i] += real(pcoef[i]) / norder2 * 2 / 3; } } + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, coef_real, coefr_cpu, norder); + } + getcoef_real = true; return; } -template -void Chebyshev::calcoef_complex(std::function(std::complex)> fun) +template +void Chebyshev::calcoef_complex(std::function(std::complex)> fun) { std::complex* pcoef = (std::complex*)this->fftw.ccoef; @@ -200,11 +237,11 @@ void Chebyshev::calcoef_complex(std::function(std::comp REAL phi = i * ModuleBase::PI / norder2; if (i == 0) { - coef_complex[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); + coefc_cpu[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); } else { - coef_complex[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); + coefc_cpu[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); } } @@ -218,11 +255,11 @@ void Chebyshev::calcoef_complex(std::function(std::comp REAL phi = i * ModuleBase::PI / norder2; if (i == 0) { - coef_complex[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); + coefc_cpu[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * 
pcoef[i].imag()) / norder2 * 2 / 3); } else { - coef_complex[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); + coefc_cpu[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); } } @@ -238,11 +275,11 @@ void Chebyshev::calcoef_complex(std::function(std::comp { if (i == 0) { - coef_complex[i].real(real(coef_complex[i]) + real(pcoef[i]) / norder2 * 1 / 3); + coefc_cpu[i].real(real(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 1 / 3); } else { - coef_complex[i].real(real(coef_complex[i]) + real(pcoef[i]) / norder2 * 2 / 3); + coefc_cpu[i].real(real(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } @@ -255,20 +292,24 @@ void Chebyshev::calcoef_complex(std::function(std::comp { if (i == 0) { - coef_complex[i].imag(imag(coef_complex[i]) + real(pcoef[i]) / norder2 * 1 / 3); + coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 1 / 3); } else { - coef_complex[i].imag(imag(coef_complex[i]) + real(pcoef[i]) / norder2 * 2 / 3); + coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder); + } getcoef_complex = true; return; } -template -void Chebyshev::calcoef_pair(std::function fun1, std::function fun2) +template +void Chebyshev::calcoef_pair(std::function fun1, std::function fun2) { std::complex* pcoef = (std::complex*)this->fftw.ccoef; @@ -286,11 +327,11 @@ void Chebyshev::calcoef_pair(std::function fun1, std::function REAL phi = i * ModuleBase::PI / norder2; if (i == 0) { - coef_complex[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); + coefc_cpu[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); } else { - coef_complex[i].real((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); + coefc_cpu[i].real((cos(phi) * 
pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); } } @@ -304,11 +345,11 @@ void Chebyshev::calcoef_pair(std::function fun1, std::function REAL phi = i * ModuleBase::PI / norder2; if (i == 0) { - coef_complex[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); + coefc_cpu[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 2 / 3); } else { - coef_complex[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); + coefc_cpu[i].imag((cos(phi) * pcoef[i].real() + sin(phi) * pcoef[i].imag()) / norder2 * 4 / 3); } } @@ -324,11 +365,11 @@ void Chebyshev::calcoef_pair(std::function fun1, std::function { if (i == 0) { - coef_complex[i].real(real(coef_complex[i]) + real(pcoef[i]) / norder2 * 1 / 3); + coefc_cpu[i].real(real(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 1 / 3); } else { - coef_complex[i].real(real(coef_complex[i]) + real(pcoef[i]) / norder2 * 2 / 3); + coefc_cpu[i].real(real(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } @@ -341,132 +382,159 @@ void Chebyshev::calcoef_pair(std::function fun1, std::function { if (i == 0) { - coef_complex[i].imag(imag(coef_complex[i]) + real(pcoef[i]) / norder2 * 1 / 3); + coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 1 / 3); } else { - coef_complex[i].imag(imag(coef_complex[i]) + real(pcoef[i]) / norder2 * 2 / 3); + coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, coef_complex, coefc_cpu, norder); + } + getcoef_complex = true; return; } -template -void Chebyshev::calfinalvec_real(std::function*, std::complex*, const int)> funA, - std::complex* wavein, - std::complex* waveout, - const int N, - const int LDA, - const int m) +template +void Chebyshev::calfinalvec_real( + std::function*, std::complex*, const int)> funA, + std::complex* wavein, + 
std::complex* waveout, + const int N, + const int LDA, + const int m) { - if (!getcoef_real) { + if (!getcoef_real) + { ModuleBase::WARNING_QUIT("Chebyshev", "Please calculate coef_real first!"); -} + } - std::complex* arraynp1; - std::complex* arrayn; - std::complex* arrayn_1; + std::complex* arraynp1 = nullptr; + std::complex* arrayn = nullptr; + std::complex* arrayn_1 = nullptr; assert(N >= 0 && LDA >= N); int ndmxt; - if (m == 1) { + if (m == 1) + { ndmxt = N * m; - } else { + } + else + { ndmxt = LDA * m; -} + } - arraynp1 = new std::complex[ndmxt]; - arrayn = new std::complex[ndmxt]; - arrayn_1 = new std::complex[ndmxt]; + resmem_complex_op()(this->ctx, arraynp1, ndmxt); + resmem_complex_op()(this->ctx, arrayn, ndmxt); + resmem_complex_op()(this->ctx, arrayn_1, ndmxt); - ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt); + memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt); + // ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt); funA(arrayn_1, arrayn, m); // 0- & 1-st order - for (int i = 0; i < ndmxt; ++i) - { - waveout[i] = coef_real[0] * arrayn_1[i] + coef_real[1] * arrayn[i]; - } + setmem_complex_op()(this->ctx, waveout, 0, ndmxt); + std::complex coef0 = std::complex(coefr_cpu[0], 0); + container::kernels::blas_axpy, ct_Device>()(ndmxt, &coef0, arrayn_1, 1, waveout, 1); + std::complex coef1 = std::complex(coefr_cpu[1], 0); + container::kernels::blas_axpy, ct_Device>()(ndmxt, &coef1, arrayn, 1, waveout, 1); + // for (int i = 0; i < ndmxt; ++i) + // { + // waveout[i] = coef_real[0] * arrayn_1[i] + coef_real[1] * arrayn[i]; + // } // more than 1-st orders for (int ior = 2; ior < norder; ++ior) { recurs_complex(funA, arraynp1, arrayn, arrayn_1, N, LDA, m); - for (int i = 0; i < ndmxt; ++i) - { - waveout[i] += coef_real[ior] * arraynp1[i]; - } + std::complex coefior = std::complex(coefr_cpu[ior], 0); + container::kernels::blas_axpy, ct_Device>()(ndmxt, &coefior, arraynp1, 1, waveout, 1); + // for (int i = 0; i < ndmxt; ++i) + // { + 
// waveout[i] += coef_real[ior] * arraynp1[i]; + // } std::complex* tem = arrayn_1; arrayn_1 = arrayn; arrayn = arraynp1; arraynp1 = tem; } - delete[] arraynp1; - delete[] arrayn; - delete[] arrayn_1; + delmem_complex_op()(this->ctx, arraynp1); + delmem_complex_op()(this->ctx, arrayn); + delmem_complex_op()(this->ctx, arrayn_1); return; } -template -void Chebyshev::calfinalvec_complex(std::function*, std::complex*, const int)> funA, - std::complex* wavein, - std::complex* waveout, - const int N, - const int LDA, - const int m) +template +void Chebyshev::calfinalvec_complex( + std::function*, std::complex*, const int)> funA, + std::complex* wavein, + std::complex* waveout, + const int N, + const int LDA, + const int m) { - if (!getcoef_complex) { - ModuleBase::WARNING_QUIT("Stochastic_Chebychev", "Please calculate coef_complex first!"); -} + if (!getcoef_complex) + { + ModuleBase::WARNING_QUIT("Chebyshev", "Please calculate coef_complex first!"); + } - std::complex* arraynp1; - std::complex* arrayn; - std::complex* arrayn_1; + std::complex* arraynp1 = nullptr; + std::complex* arrayn = nullptr; + std::complex* arrayn_1 = nullptr; assert(N >= 0 && LDA >= N); int ndmxt; - if (m == 1) { + if (m == 1) + { ndmxt = N * m; - } else { + } + else + { ndmxt = LDA * m; -} + } - arraynp1 = new std::complex[ndmxt]; - arrayn = new std::complex[ndmxt]; - arrayn_1 = new std::complex[ndmxt]; + resmem_complex_op()(this->ctx, arraynp1, ndmxt); + resmem_complex_op()(this->ctx, arrayn, ndmxt); + resmem_complex_op()(this->ctx, arrayn_1, ndmxt); - ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt); + memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt); funA(arrayn_1, arrayn, m); // 0- & 1-st order - for (int i = 0; i < ndmxt; ++i) - { - waveout[i] = coef_complex[0] * arrayn_1[i] + coef_complex[1] * arrayn[i]; - } + setmem_complex_op()(this->ctx, waveout, 0, ndmxt); + container::kernels::blas_axpy, ct_Device>()(ndmxt, &coefc_cpu[0], arrayn_1, 1, waveout, 1); + 
container::kernels::blas_axpy, ct_Device>()(ndmxt, &coefc_cpu[1], arrayn, 1, waveout, 1); + // for (int i = 0; i < ndmxt; ++i) + // { + // waveout[i] = coef_complex[0] * arrayn_1[i] + coef_complex[1] * arrayn[i]; + // } // more than 1-st orders for (int ior = 2; ior < norder; ++ior) { recurs_complex(funA, arraynp1, arrayn, arrayn_1, N, LDA, m); - for (int i = 0; i < ndmxt; ++i) - { - waveout[i] += coef_complex[ior] * arraynp1[i]; - } + container::kernels::blas_axpy, ct_Device>()(ndmxt, &coefc_cpu[ior], arraynp1, 1, waveout, 1); + // for (int i = 0; i < ndmxt; ++i) + // { + // waveout[i] += coef_complex[ior] * arraynp1[i]; + // } std::complex* tem = arrayn_1; arrayn_1 = arrayn; arrayn = arraynp1; arraynp1 = tem; } - delete[] arraynp1; - delete[] arrayn; - delete[] arrayn_1; + delmem_complex_op()(this->ctx, arraynp1); + delmem_complex_op()(this->ctx, arrayn); + delmem_complex_op()(this->ctx, arrayn_1); return; } -template -void Chebyshev::calpolyvec_complex( +template +void Chebyshev::calpolyvec_complex( std::function*, std::complex*, const int)> funA, std::complex* wavein, std::complex* polywaveout, @@ -485,7 +553,8 @@ void Chebyshev::calpolyvec_complex( std::complex*tmpin = wavein, *tmpout = arrayn_1; for (int i = 0; i < m; ++i) { - ModuleBase::GlobalFunc::DCOPY(tmpin, tmpout, N); + memcpy_complex_op()(this->ctx, this->ctx, tmpout, tmpin, N); + // ModuleBase::GlobalFunc::DCOPY(tmpin, tmpout, N); tmpin += LDA; tmpout += LDA; } @@ -504,29 +573,34 @@ void Chebyshev::calpolyvec_complex( return; } -template -void Chebyshev::tracepolyA(std::function* in, std::complex* out, const int)> funA, - std::complex* wavein, - const int N, - const int LDA, - const int m) +template +void Chebyshev::tracepolyA( + std::function* in, std::complex* out, const int)> funA, + std::complex* wavein, + const int N, + const int LDA, + const int m) { - std::complex* arraynp1; - std::complex* arrayn; - std::complex* arrayn_1; + std::complex* arraynp1 = nullptr; + std::complex* arrayn = nullptr; 
+ std::complex* arrayn_1 = nullptr; assert(N >= 0 && LDA >= N); int ndmxt; - if (m == 1) { + if (m == 1) + { ndmxt = N * m; - } else { + } + else + { ndmxt = LDA * m; -} + } - arraynp1 = new std::complex[ndmxt]; - arrayn = new std::complex[ndmxt]; - arrayn_1 = new std::complex[ndmxt]; + resmem_complex_op()(this->ctx, arraynp1, ndmxt); + resmem_complex_op()(this->ctx, arrayn, ndmxt); + resmem_complex_op()(this->ctx, arrayn_1, ndmxt); - ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt); + memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, ndmxt); + // ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, ndmxt); funA(arrayn_1, arrayn, m); @@ -544,14 +618,14 @@ void Chebyshev::tracepolyA(std::function* in, std: arraynp1 = tem; } - delete[] arraynp1; - delete[] arrayn; - delete[] arrayn_1; + delmem_complex_op()(this->ctx, arraynp1); + delmem_complex_op()(this->ctx, arrayn); + delmem_complex_op()(this->ctx, arrayn_1); return; } -template -void Chebyshev::recurs_complex( +template +void Chebyshev::recurs_complex( std::function* in, std::complex* out, const int)> funA, std::complex* arraynp1, std::complex* arrayn, @@ -561,17 +635,27 @@ void Chebyshev::recurs_complex( const int m) { funA(arrayn, arraynp1, m); + const std::complex two = 2.0; + const std::complex invone = -1.0; for (int ib = 0; ib < m; ++ib) { - for (int i = 0; i < N; ++i) - { - arraynp1[i + ib * LDA] = REAL(2.0) * arraynp1[i + ib * LDA] - arrayn_1[i + ib * LDA]; - } + container::kernels::blas_scal, ct_Device>()(N, &two, arraynp1 + ib * LDA, 1); + container::kernels::blas_axpy, ct_Device>()(N, + &invone, + arrayn_1 + ib * LDA, + 1, + arraynp1 + ib * LDA, + 1); + + // for (int i = 0; i < N; ++i) + // { + // arraynp1[i + ib * LDA] = REAL(2.0) * arraynp1[i + ib * LDA] - arrayn_1[i + ib * LDA]; + // } } } -template -bool Chebyshev::checkconverge( +template +bool Chebyshev::checkconverge( std::function* in, std::complex* out, const int)> funA, std::complex* wavein, const int N, @@ -581,15 +665,16 @@ bool 
Chebyshev::checkconverge( REAL stept) { bool converge = true; - std::complex* arraynp1; - std::complex* arrayn; - std::complex* arrayn_1; + std::complex* arraynp1 = nullptr; + std::complex* arrayn = nullptr; + std::complex* arrayn_1 = nullptr; - arraynp1 = new std::complex[LDA]; - arrayn = new std::complex[LDA]; - arrayn_1 = new std::complex[LDA]; + resmem_complex_op()(this->ctx, arraynp1, LDA); + resmem_complex_op()(this->ctx, arrayn, LDA); + resmem_complex_op()(this->ctx, arrayn_1, LDA); - ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N); + memcpy_complex_op()(this->ctx, this->ctx, arrayn_1, wavein, N); + // ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N); if (tmin == tmax) { @@ -599,13 +684,21 @@ bool Chebyshev::checkconverge( funA(arrayn_1, arrayn, 1); REAL sum1, sum2; REAL t; + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + sum1 = this->ddot_real(arrayn_1, arrayn_1, N); + sum2 = this->ddot_real(arrayn_1, arrayn, N); + } + else + { #ifdef __MPI - sum1 = ModuleBase::GlobalFunc::ddot_real(N, arrayn_1, arrayn_1); - sum2 = ModuleBase::GlobalFunc::ddot_real(N, arrayn_1, arrayn); + sum1 = ModuleBase::GlobalFunc::ddot_real(N, arrayn_1, arrayn_1); + sum2 = ModuleBase::GlobalFunc::ddot_real(N, arrayn_1, arrayn); #else - sum1 = this->ddot_real(arrayn_1, arrayn_1, N); - sum2 = this->ddot_real(arrayn_1, arrayn, N); + sum1 = this->ddot_real(arrayn_1, arrayn_1, N); + sum2 = this->ddot_real(arrayn_1, arrayn, N); #endif + } t = sum2 / sum1 * (tmax - tmin) / 2 + (tmax + tmin) / 2; if (t < tmin || tmin == 0) { @@ -621,13 +714,21 @@ bool Chebyshev::checkconverge( for (int ior = 2; ior < norder; ++ior) { funA(arrayn, arraynp1, 1); + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + { + sum1 = this->ddot_real(arrayn, arrayn, N); + sum2 = this->ddot_real(arrayn, arraynp1, N); + } + else + { #ifdef __MPI - sum1 = ModuleBase::GlobalFunc::ddot_real(N, arrayn, arrayn); - sum2 = ModuleBase::GlobalFunc::ddot_real(N, arrayn, 
arraynp1); + sum1 = ModuleBase::GlobalFunc::ddot_real(N, arrayn, arrayn); + sum2 = ModuleBase::GlobalFunc::ddot_real(N, arrayn, arraynp1); #else - sum1 = this->ddot_real(arrayn, arrayn, N); - sum2 = this->ddot_real(arrayn, arraynp1, N); + sum1 = this->ddot_real(arrayn, arrayn, N); + sum2 = this->ddot_real(arrayn, arraynp1, N); #endif + } t = sum2 / sum1 * (tmax - tmin) / 2 + (tmax + tmin) / 2; if (t < tmin) { @@ -639,19 +740,23 @@ bool Chebyshev::checkconverge( converge = false; tmax = t + stept; } - for (int i = 0; i < N; ++i) - { - arraynp1[i] = REAL(2.0) * arraynp1[i] - arrayn_1[i]; - } + std::complex two = 2.0; + std::complex invone = -1.0; + container::kernels::blas_scal, ct_Device>()(N, &two, arraynp1, 1); + container::kernels::blas_axpy, ct_Device>()(N, &invone, arrayn_1, 1, arraynp1, 1); + // for (int i = 0; i < N; ++i) + // { + // arraynp1[i] = REAL(2.0) * arraynp1[i] - arrayn_1[i]; + // } std::complex* tem = arrayn_1; arrayn_1 = arrayn; arrayn = arraynp1; arraynp1 = tem; } - delete[] arraynp1; - delete[] arrayn; - delete[] arrayn_1; + delmem_complex_op()(this->ctx, arraynp1); + delmem_complex_op()(this->ctx, arrayn); + delmem_complex_op()(this->ctx, arrayn_1); return converge; } @@ -660,5 +765,8 @@ template class Chebyshev; #ifdef __ENABLE_FLOAT_FFTW template class Chebyshev; #endif +#if ((defined __CUDA) || (defined __ROCM)) +template class Chebyshev; +#endif } // namespace ModuleBase diff --git a/source/module_base/math_chebyshev.h b/source/module_base/math_chebyshev.h index a62a5520d2..122f2021e9 100644 --- a/source/module_base/math_chebyshev.h +++ b/source/module_base/math_chebyshev.h @@ -1,6 +1,9 @@ #ifndef STO_CHEBYCHEV_H #define STO_CHEBYCHEV_H #include "fftw3.h" +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#include "module_base/module_container/ATen/core/tensor_types.h" #include #include @@ -76,7 +79,7 @@ class FFTW; * //calculate vp1: |vp1> = 2 H|v> - |vm1>; * */ -template +template class 
Chebyshev { @@ -201,15 +204,18 @@ class Chebyshev int norder; // order of Chebyshev expansion int norder2; // 2 * norder * EXTEND - REAL* coef_real; // expansion coefficient of each order - std::complex* coef_complex; // expansion coefficient of each order - FFTW fftw; // use for fftw - REAL* polytrace; // w_n = \sum_i v^+ * T_n(A) * v + REAL* coef_real = nullptr; //[Device] expansion coefficient of each order + std::complex* coef_complex = nullptr; //[Device] expansion coefficient of each order + REAL* coefr_cpu = nullptr; //[CPU] expansion coefficient of each order + std::complex* coefc_cpu = nullptr; //[CPU] expansion coefficient of each order + + FFTW fftw; // use for fftw + REAL* polytrace; //[CPU] w_n = \sum_i v^+ * T_n(A) * v, only bool getcoef_real; // coef_real has been calculated bool getcoef_complex; // coef_complex has been calculated - private: + public: // SI. // calculate dot product REAL ddot_real(const std::complex* psi_L, @@ -217,6 +223,22 @@ class Chebyshev const int N, const int LDA = 1, const int m = 1); + + private: + Device* ctx = {}; + base_device::DEVICE_CPU* cpu_ctx = {}; + using ct_Device = typename container::PsiToContainer::type; + using resmem_complex_op = base_device::memory::resize_memory_op, Device>; + using resmem_var_op = base_device::memory::resize_memory_op; + using delmem_complex_op = base_device::memory::delete_memory_op, Device>; + using delmem_var_op = base_device::memory::delete_memory_op; + using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op; + using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op; + using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op, Device, base_device::DEVICE_CPU>; + using syncmem_complex_d2h_op = base_device::memory::synchronize_memory_op, base_device::DEVICE_CPU, Device>; + using memcpy_var_op = base_device::memory::synchronize_memory_op; + using memcpy_complex_op = base_device::memory::synchronize_memory_op, Device, Device>; + using 
setmem_complex_op = base_device::memory::set_memory_op, Device>; }; template <> @@ -244,6 +266,7 @@ class FFTW fftwf_plan coef_plan; }; #endif + } // namespace ModuleBase #endif diff --git a/source/module_base/module_container/CMakeLists.txt b/source/module_base/module_container/CMakeLists.txt index 439ca675d0..80a8dce5fc 100644 --- a/source/module_base/module_container/CMakeLists.txt +++ b/source/module_base/module_container/CMakeLists.txt @@ -19,7 +19,7 @@ if(USE_ROCM) set(ATen_ROCM_DEPENDENCY_LIBS container_rocm) endif() -add_library(container OBJECT ${ATen_CPU_SRCS} ${ATen_CUDA_SRCS}) +add_library(container STATIC ${ATen_CPU_SRCS} ${ATen_CUDA_SRCS}) target_link_libraries(container PUBLIC ${ATen_CPU_DEPENDENCY_LIBS} ${ATen_CUDA_DEPENDENCY_LIBS} ${ATen_ROCM_DEPENDENCY_LIBS}) diff --git a/source/module_base/module_device/cuda/memory_op.cu b/source/module_base/module_device/cuda/memory_op.cu index 29be90a612..eaeb505071 100644 --- a/source/module_base/module_device/cuda/memory_op.cu +++ b/source/module_base/module_device/cuda/memory_op.cu @@ -38,6 +38,18 @@ __global__ void cast_memory(std::complex* out, const std::complex>(_in[idx]); } +template +__global__ void cast_memory(std::complex* out, const FPTYPE_in* in, const int size) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) + { + return; + } + auto* _out = reinterpret_cast*>(out); + _out[idx] = static_cast>(in[idx]); +} + template void resize_memory_op::operator()(const base_device::DEVICE_GPU* dev, FPTYPE*& arr, @@ -223,6 +235,8 @@ template struct cast_memory_op, std::complex, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; template struct cast_memory_op; template struct cast_memory_op; template struct cast_memory_op; diff --git a/source/module_base/module_device/memory_op.cpp 
b/source/module_base/module_device/memory_op.cpp index f989924d30..68146c275a 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -138,6 +138,8 @@ template struct cast_memory_op, std::complex, base_device::DEVICE_CPU, base_device::DEVICE_CPU>; +template struct cast_memory_op, float, base_device::DEVICE_CPU, base_device::DEVICE_CPU>; +template struct cast_memory_op, double, base_device::DEVICE_CPU, base_device::DEVICE_CPU>; template struct delete_memory_op; template struct delete_memory_op; diff --git a/source/module_base/parallel_device.h b/source/module_base/parallel_device.h new file mode 100644 index 0000000000..8d867ba4fc --- /dev/null +++ b/source/module_base/parallel_device.h @@ -0,0 +1,116 @@ +#ifndef PARALLEL_DEVICE_H +#define PARALLEL_DEVICE_H +#ifdef __MPI +#include "mpi.h" +#include "module_base/module_device/device.h" +#include +#include +#include +namespace Parallel_Common +{ +/* Host-buffer broadcasts from rank 0 of comm (a complex value is sent as two scalars). Declared inline because the definitions live in this header: without inline, a second translation unit including it would produce duplicate symbols at link time. */ +inline void bcast_complex(std::complex* object, const int& n, const MPI_Comm& comm) +{ + MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm); +} +inline void bcast_complex(std::complex* object, const int& n, const MPI_Comm& comm) +{ + MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm); +} +inline void bcast_real(double* object, const int& n, const MPI_Comm& comm) +{ + MPI_Bcast(object, n, MPI_DOUBLE, 0, comm); +} +inline void bcast_real(float* object, const int& n, const MPI_Comm& comm) +{ + MPI_Bcast(object, n, MPI_FLOAT, 0, comm); +} + +template +/** + * @brief bcast complex in Device + * + * @param ctx Device ctx + * @param object complex arrays in Device + * @param n the size of complex arrays + * @param comm MPI_Comm + * @param tmp_space tmp space in CPU + */ +void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +{ + const base_device::DEVICE_CPU* cpu_ctx = {}; + T* object_cpu = nullptr; + bool alloc = false; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + if(tmp_space == nullptr) + { + 
base_device::memory::resize_memory_op()(cpu_ctx, object_cpu, n); + alloc = true; + } + else + { + object_cpu = tmp_space; + } + base_device::memory::synchronize_memory_op()(cpu_ctx, ctx, object_cpu, object, n); + } + else + { + object_cpu = object; + } + + bcast_complex(object_cpu, n, comm); + + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + base_device::memory::synchronize_memory_op()(ctx, cpu_ctx, object, object_cpu, n); + if(alloc) + { + base_device::memory::delete_memory_op()(cpu_ctx, object_cpu); + } + } + return; +} + +/* Real-valued counterpart of the device-aware bcast_complex: for a GPU ctx the data is staged through tmp_space (or a temporary host buffer when tmp_space is nullptr), broadcast on the host, then copied back to the device. */ +template +void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +{ + const base_device::DEVICE_CPU* cpu_ctx = {}; + T* object_cpu = nullptr; + bool alloc = false; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + if(tmp_space == nullptr) + { + base_device::memory::resize_memory_op()(cpu_ctx, object_cpu, n); + alloc = true; + } + else + { + object_cpu = tmp_space; + } + base_device::memory::synchronize_memory_op()(cpu_ctx, ctx, object_cpu, object, n); + } + else + { + object_cpu = object; + } + + bcast_real(object_cpu, n, comm); + + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + base_device::memory::synchronize_memory_op()(ctx, cpu_ctx, object, object_cpu, n); + if(alloc) + { + base_device::memory::delete_memory_op()(cpu_ctx, object_cpu); + } + } + return; +} +} + + +#endif +#endif /* PARALLEL_DEVICE_H */ \ No newline at end of file diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 86e05d9478..786c72ca4d 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -134,8 +134,8 @@ AddTest( AddTest( TARGET base_math_chebyshev - LIBS parameter ${math_libs} - SOURCES math_chebyshev_test.cpp ../blas_connector.cpp ../math_chebyshev.cpp ../tool_quit.cpp ../global_variable.cpp ../timer.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp + LIBS parameter ${math_libs} 
device container + SOURCES math_chebyshev_test.cpp ../blas_connector.cpp ../math_chebyshev.cpp ../tool_quit.cpp ../global_variable.cpp ../timer.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../parallel_reduce.cpp ) AddTest( diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index aebde68a08..d73163da77 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -39,6 +39,10 @@ ElecStatePW::~ElecStatePW() delmem_var_op()(this->ctx, this->kin_r_data); } } + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { + delete[] this->rho; + delete[] this->kin_r; + } delmem_var_op()(this->ctx, becsum); delmem_complex_op()(this->ctx, this->wfcr); delmem_complex_op()(this->ctx, this->wfcr_another_spin); @@ -47,6 +51,10 @@ ElecStatePW::~ElecStatePW() template void ElecStatePW::init_rho_data() { + if(this->init_rho) { + return; + } + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { this->rho = new Real*[this->charge->nspin]; resmem_var_op()(this->ctx, this->rho_data, this->charge->nspin * this->charge->nrxx); @@ -80,9 +88,7 @@ void ElecStatePW::psiToRho(const psi::Psi& psi) ModuleBase::TITLE("ElecStatePW", "psiToRho"); ModuleBase::timer::tick("ElecStatePW", "psiToRho"); - if (!this->init_rho) { - this->init_rho_data(); - } + this->init_rho_data(); this->calculate_weights(); this->calEBand(); @@ -154,9 +160,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) // if (PARAM.inp.nspin == 4) // wfcr_another_spin.resize(this->charge->nrxx); - if (!this->init_rho) { - this->init_rho_data(); - } + this->init_rho_data(); int ik = psi.get_current_k(); int npw = psi.get_current_nbas(); int current_spin = 0; diff --git a/source/module_elecstate/elecstate_pw.h b/source/module_elecstate/elecstate_pw.h index 73cf6c91a2..408dd76a96 100644 --- a/source/module_elecstate/elecstate_pw.h +++ b/source/module_elecstate/elecstate_pw.h @@ -39,6 +39,9 @@ class 
ElecStatePW : public ElecState // void getNewRho() override; Real* becsum = nullptr; + // init rho_data and kin_r_data + void init_rho_data(); + Real ** rho = nullptr, ** kin_r = nullptr; //[Device] [spin][nrxx] rho and kin_r protected: ModulePW::PW_Basis* rhopw_smooth = nullptr; @@ -58,15 +61,12 @@ class ElecStatePW : public ElecState // \sum_lm Q_lm(r) \sum_i w_i void addusdens_g(const Real* becsum, T* rhog); - void init_rho_data(); - Device * ctx = {}; bool init_rho = false; mutable T* vkb = nullptr; - Real ** rho = nullptr, ** kin_r = nullptr; Real * rho_data = nullptr, * kin_r_data = nullptr; T * wfcr = nullptr, * wfcr_another_spin = nullptr; - + private: using meta_op = hamilt::meta_pw_op; using elecstate_pw_op = elecstate::elecstate_pw_op; diff --git a/source/module_elecstate/elecstate_pw_sdft.cpp b/source/module_elecstate/elecstate_pw_sdft.cpp index c36513c838..d534ac95f5 100644 --- a/source/module_elecstate/elecstate_pw_sdft.cpp +++ b/source/module_elecstate/elecstate_pw_sdft.cpp @@ -9,26 +9,19 @@ namespace elecstate { template -void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) +void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) { ModuleBase::TITLE(this->classname, "psiToRho"); ModuleBase::timer::tick(this->classname, "psiToRho"); - for (int is = 0; is < PARAM.inp.nspin; is++) - { - ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) - { - ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); - } - } + const int nspin = PARAM.inp.nspin; if (GlobalV::MY_STOGROUP == 0) { this->calEBand(); - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { - ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); + setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx); } for (int ik = 0; ik < psi.get_nk(); ++ik) @@ -36,6 +29,11 @@ void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) psi.fix_k(ik); this->updateRhoK(psi); } 
+ if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { + for (int ii = 0; ii < nspin; ii++) { + castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx); + } + } this->parallelK(); } ModuleBase::timer::tick(this->classname, "psiToRho"); @@ -44,4 +42,7 @@ void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) // template class ElecStatePW_SDFT, base_device::DEVICE_CPU>; template class ElecStatePW_SDFT, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class ElecStatePW_SDFT, base_device::DEVICE_GPU>; +#endif } // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw_sdft.h b/source/module_elecstate/elecstate_pw_sdft.h index eaf9215390..fe9ab81834 100644 --- a/source/module_elecstate/elecstate_pw_sdft.h +++ b/source/module_elecstate/elecstate_pw_sdft.h @@ -20,7 +20,11 @@ class ElecStatePW_SDFT : public ElecStatePW { this->classname = "ElecStatePW_SDFT"; } - virtual void psiToRho(const psi::Psi& psi) override; + virtual void psiToRho(const psi::Psi& psi) override; + private: + using Real = typename GetTypeReal::type; + using setmem_var_op = base_device::memory::set_memory_op; + using castmem_var_d2h_op = base_device::memory::cast_memory_op; }; } // namespace elecstate #endif \ No newline at end of file diff --git a/source/module_esolver/esolver.cpp b/source/module_esolver/esolver.cpp index 74ef101579..31cced7758 100644 --- a/source/module_esolver/esolver.cpp +++ b/source/module_esolver/esolver.cpp @@ -150,6 +150,19 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell) } else if (esolver_type == "sdft_pw") { +#if ((defined __CUDA) || (defined __ROCM)) + if (PARAM.inp.device == "gpu") + { + // if (PARAM.inp.precision == "single") + // { + // return new ESolver_SDFT_PW, base_device::DEVICE_GPU>(); + // } + // else + // { + return new ESolver_SDFT_PW, base_device::DEVICE_GPU>(); + // } + } +#endif // if (PARAM.inp.precision == 
"single") // { // return new ESolver_SDFT_PW, base_device::DEVICE_CPU>(); diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 2e21355208..023f800443 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -206,7 +206,7 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr hsolver::DiagoIterAssist::need_subspace, this->init_psi); - hsolver_pw_sdft_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false); + hsolver_pw_sdft_obj.solve(this->p_hamilt, this->kspw_psi[0], this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false); this->init_psi = true; // set_diagethr need it @@ -241,8 +241,8 @@ double ESolver_SDFT_PW::cal_energy() return this->pelec->f_en.etot; } -template -void ESolver_SDFT_PW::cal_force(ModuleBase::matrix& force) +template <> +void ESolver_SDFT_PW, base_device::DEVICE_CPU>::cal_force(ModuleBase::matrix& force) { Sto_Forces ff(GlobalC::ucell.nat); @@ -257,8 +257,14 @@ void ESolver_SDFT_PW::cal_force(ModuleBase::matrix& force) this->stowf); } -template -void ESolver_SDFT_PW::cal_stress(ModuleBase::matrix& stress) +template <> +void ESolver_SDFT_PW, base_device::DEVICE_GPU>::cal_force(ModuleBase::matrix& force) +{ + ModuleBase::WARNING_QUIT("ESolver_SDFT_PW::cal_force", "DEVICE_GPU is not supported"); +} + +template <> +void ESolver_SDFT_PW, base_device::DEVICE_CPU>::cal_stress(ModuleBase::matrix& stress) { Sto_Stress_PW ss; ss.cal_stress(stress, @@ -275,6 +281,12 @@ void ESolver_SDFT_PW::cal_stress(ModuleBase::matrix& stress) GlobalC::ucell); } +template <> +void ESolver_SDFT_PW, base_device::DEVICE_GPU>::cal_stress(ModuleBase::matrix& stress) +{ + ModuleBase::WARNING_QUIT("ESolver_SDFT_PW::cal_stress", "DEVICE_GPU is not supported"); +} + template void ESolver_SDFT_PW::after_all_runners() { @@ -379,4 +391,8 @@ void ESolver_SDFT_PW::nscf() // template class ESolver_SDFT_PW, 
base_device::DEVICE_CPU>; template class ESolver_SDFT_PW, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +// template class ESolver_SDFT_PW, base_device::DEVICE_GPU>; +template class ESolver_SDFT_PW, base_device::DEVICE_GPU>; +#endif } // namespace ModuleESolver diff --git a/source/module_esolver/esolver_sdft_pw.h b/source/module_esolver/esolver_sdft_pw.h index 71e5670c44..63e613befc 100644 --- a/source/module_esolver/esolver_sdft_pw.h +++ b/source/module_esolver/esolver_sdft_pw.h @@ -13,6 +13,8 @@ namespace ModuleESolver template class ESolver_SDFT_PW : public ESolver_KS_PW { + private: + using Real = typename GetTypeReal::type; public: ESolver_SDFT_PW(); ~ESolver_SDFT_PW(); @@ -27,7 +29,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW public: Stochastic_WF stowf; - StoChe stoche; + StoChe stoche; hamilt::HamiltSdftPW* p_hamilt_sto = nullptr; protected: diff --git a/source/module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.cpp b/source/module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.cpp index 272a8c1539..7f0077871f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.cpp @@ -1,5 +1,6 @@ #include "hamilt_sdft_pw.h" #include "module_base/timer.h" +#include "kernels/hpsi_norm_op.h" namespace hamilt { @@ -56,24 +57,14 @@ void HamiltSdftPW::hPsi_norm(const T* psi_in, T* hpsi_norm, const int const Real Ebar = (emin + emax) / 2; const Real DeltaE = (emax - emin) / 2; -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int ib = 0; ib < nbands; ++ib) - { - const int ig0 = ib * npwk_max; - for (int ig = 0; ig < npwk; ++ig) - { - hpsi_norm[ig + ig0] = (hpsi_norm[ig + ig0] - Ebar * psi_in[ig + ig0]) / DeltaE; - } - } + hpsi_norm_op()(this->ctx, nbands, npwk_max, npwk, Ebar, DeltaE, hpsi_norm, psi_in); ModuleBase::timer::tick("HamiltSdftPW", "hPsi_norm"); } -template class HamiltSdftPW, base_device::DEVICE_CPU>; +// template class HamiltSdftPW, base_device::DEVICE_CPU>; template 
class HamiltSdftPW, base_device::DEVICE_CPU>; #if ((defined __CUDA) || (defined __ROCM)) -template class HamiltSdftPW, base_device::DEVICE_GPU>; +// template class HamiltSdftPW, base_device::DEVICE_GPU>; template class HamiltSdftPW, base_device::DEVICE_GPU>; #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu b/source/module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu new file mode 100644 index 0000000000..608ec886e6 --- /dev/null +++ b/source/module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu @@ -0,0 +1,49 @@ +#include "module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.h" + +#include +#include +#include + +namespace hamilt +{ +#define THREADS_PER_BLOCK 256 + +template +__global__ void hpsi_norm(const int npwk_max, + const int npwk, + const FPTYPE Ebar, + const FPTYPE DeltaE, + thrust::complex* hpsi, + const thrust::complex* psi_in) +{ + const int block_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int start_idx = block_idx * npwk_max; + for (int ii = thread_idx; ii < npwk; ii += blockDim.x) + { + hpsi[start_idx + ii] = (hpsi[start_idx + ii] - Ebar * psi_in[start_idx + ii]) / DeltaE; + } +} + +template +void hamilt::hpsi_norm_op::operator()(const base_device::DEVICE_GPU* dev, + const int& nbands, + const int& npwk_max, + const int& npwk, + const FPTYPE& Ebar, + const FPTYPE& DeltaE, + std::complex* hpsi, + const std::complex* psi_in) +{ + // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + hpsi_norm<<>>( + npwk_max, npwk, Ebar, DeltaE, + reinterpret_cast*>(hpsi), + reinterpret_cast*>(psi_in)); + cudaCheckOnDebug(); +} + +template struct hpsi_norm_op; +template struct hpsi_norm_op; + +} // namespace hamilt \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp b/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp new file mode 100644 index 0000000000..0457c0e98e --- /dev/null +++ 
b/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp @@ -0,0 +1,35 @@ +#include "hpsi_norm_op.h" + +#include "module_base/module_device/device.h" +namespace hamilt +{ +template +struct hpsi_norm_op +{ + void operator()(const base_device::DEVICE_CPU* dev, + const int& nbands, + const int& npwk_max, + const int& npwk, + const FPTYPE& Ebar, + const FPTYPE& DeltaE, + std::complex* hpsi_norm, + const std::complex* psi_in) + { +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int ib = 0; ib < nbands; ++ib) + { + const int ig0 = ib * npwk_max; + for (int ig = 0; ig < npwk; ++ig) + { + hpsi_norm[ig + ig0] = (hpsi_norm[ig + ig0] - Ebar * psi_in[ig + ig0]) / DeltaE; + } + } + } +}; + +// template struct hpsi_norm_op; +template struct hpsi_norm_op; + +} // namespace hamilt \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.h b/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.h new file mode 100644 index 0000000000..b422916ae8 --- /dev/null +++ b/source/module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.h @@ -0,0 +1,48 @@ +#ifndef HPSI_NORM_OP_H +#define HPSI_NORM_OP_H +#include +#include "module_base/module_device/device.h" +namespace hamilt +{ +template +struct hpsi_norm_op +{ + /// @brief normalize hPsi with emin and emax + /// + /// Input Parameters + /// \param dev : the type of computing device + /// \param nbands : nbands + /// \param npwk_max : max number of planewaves of all k points + /// \param npwk : number of planewaves of current k point + /// \param Ebar : (emin + emax) / 2 + /// \param DeltaE : (emax - emin) / 2 + /// \param hpsi_norm : hPsi + /// \param psi_in : input psi + /// Output Parameters + /// \param tmhpsi : output array + void operator()(const Device* dev, + const int& nbands, + const int& npwk_max, + const int& npwk, + const FPTYPE& Ebar, + const FPTYPE& DeltaE, + std::complex* hpsi_norm, + const std::complex* psi_in); +}; +#if __CUDA || __UT_USE_CUDA || __ROCM || 
__UT_USE_ROCM +// Partially specialize functor for base_device::GpuDevice. +template +struct hpsi_norm_op +{ + void operator()(const base_device::DEVICE_GPU* dev, + const int& nbands, + const int& npwk_max, + const int& npwk, + const FPTYPE& Ebar, + const FPTYPE& DeltaE, + std::complex* hpsi_norm, + const std::complex* psi_in); +}; +#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM +} // namespace hamilt +#endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu b/source/module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu new file mode 100644 index 0000000000..6566e5610b --- /dev/null +++ b/source/module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu @@ -0,0 +1,54 @@ +#include "module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.h" + +#include + +#include +#include + +namespace hamilt +{ +#define THREADS_PER_BLOCK 256 + +template +__global__ void hpsi_norm(const int npwk_max, + const int npwk, + const FPTYPE Ebar, + const FPTYPE DeltaE, + thrust::complex* hpsi, + const thrust::complex* psi_in) +{ + const int block_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int start_idx = block_idx * npwk_max; + for (int ii = thread_idx; ii < npwk; ii += blockDim.x) + { + hpsi[start_idx + ii] = (hpsi[start_idx + ii] - Ebar * psi_in[start_idx + ii]) / DeltaE; + } +} + +template +void hamilt::hpsi_norm_op::operator()(const base_device::DEVICE_GPU* dev, + const int& nbands, + const int& npwk_max, + const int& npwk, + const FPTYPE& Ebar, + const FPTYPE& DeltaE, + std::complex* hpsi, + const std::complex* psi_in) +{ + // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + // hpsi_norm<<>>( + // npwk_max, npwk, + // reinterpret_cast*>(hpsi_norm), + // reinterpret_cast*>(psi_in)); + hipLaunchKernelGGL(HIP_KERNEL_NAME(hpsi_norm), dim3(nbands), dim3(THREADS_PER_BLOCK), 0, 0, + npwk_max, npwk, Ebar, DeltaE, + reinterpret_cast*>(hpsi), + 
reinterpret_cast*>(psi_in)); + cudaCheckOnDebug(); +} + +template struct hpsi_norm_op; +template struct hpsi_norm_op; + +} // namespace hamilt \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp index 2f404b9c9c..23a5a18926 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp @@ -1,44 +1,38 @@ #include "sto_che.h" #include "module_base/blas_connector.h" +#include "module_base/module_device/device.h" +#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/module_container/ATen/kernels/blas.h" -template -StoChe::~StoChe() +template +StoChe::~StoChe() { delete p_che; - delete[] spolyv; + delete[] spolyv_cpu; + delmem_var_op()(this->ctx, spolyv); } -template -StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) +template +StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) { this->nche = nche; this->method_sto = method; - p_che = new ModuleBase::Chebyshev(nche); + p_che = new ModuleBase::Chebyshev(nche); if (method == 1) { - spolyv = new REAL[nche]; + resmem_var_op()(this->ctx, spolyv, nche); + spolyv_cpu = new REAL[nche]; } else { - spolyv = new REAL[nche * nche]; + resmem_var_op()(this->ctx, spolyv, nche * nche); } this->emax_sto = emax_sto; this->emin_sto = emin_sto; } -template class StoChe; -// template class StoChe; - -double vTMv(const double* v, const double* M, const int n) -{ - const char normal = 'N'; - const double one = 1; - const int inc = 1; - const double zero = 0; - double* y = new double[n]; - dgemv_(&normal, &n, &n, &one, M, &n, v, &inc, &zero, y, &inc); - double result = BlasConnector::dot(n, y, 1, v, 1); - delete[] y; - return result; -} \ No newline at end of file +template class StoChe; +#if ((defined __CUDA) || (defined __ROCM)) +template class StoChe; +#endif \ No newline at end of file 
diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.h b/source/module_hamilt_pw/hamilt_stodft/sto_che.h index e2ffc1baca..3a7d2f0090 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.h @@ -1,8 +1,10 @@ #ifndef STO_CHE_H #define STO_CHE_H #include "module_base/math_chebyshev.h" +#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/module_container/ATen/kernels/blas.h" -template +template class StoChe { public: @@ -10,26 +12,58 @@ class StoChe ~StoChe(); public: - int nche = 0; ///< order of Chebyshev expansion - REAL* spolyv = nullptr; ///< coefficients of Chebyshev expansion - int method_sto = 0; ///< method for the stochastic calculation + int nche = 0; ///< order of Chebyshev expansion + REAL* spolyv = nullptr; ///< [Device] coefficients of Chebyshev expansion + REAL* spolyv_cpu = nullptr; ///< [CPU] coefficients of Chebyshev expansion + int method_sto = 0; ///< method for the stochastic calculation // Chebyshev expansion // It stores the plan of FFTW and should be initialized at the beginning of the calculation - ModuleBase::Chebyshev* p_che = nullptr; + ModuleBase::Chebyshev* p_che = nullptr; REAL emax_sto = 0.0; ///< maximum energy for normalization REAL emin_sto = 0.0; ///< minimum energy for normalization + + private: + Device* ctx = {}; + using resmem_var_op = base_device::memory::resize_memory_op; + using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op; + using delmem_var_op = base_device::memory::delete_memory_op; }; /** * @brief calculate v^T*M*v - * + * * @param v v * @param M M * @param n the dimension of v - * @return double + * @return REAL */ -double vTMv(const double* v, const double* M, const int n); +template +REAL vTMv(const REAL* v, const REAL* M, const int n) +{ + Device* ctx = {}; + base_device::DEVICE_CPU* cpu_ctx = {}; + using ct_Device = typename container::PsiToContainer::type; + const char normal = 'N'; + const REAL one = 1; + const 
int inc = 1; + const REAL zero = 0; + REAL* y = nullptr; + base_device::memory::resize_memory_op()(ctx, y, n); + hsolver::gemv_op()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc); + REAL result = 0; + REAL* dot_device = nullptr; + base_device::memory::resize_memory_op()(ctx, dot_device, 1); + container::kernels::blas_dot()(n, y, 1, v, 1, dot_device); + base_device::memory::synchronize_memory_op()(cpu_ctx, + ctx, + &result, + dot_device, + 1); + base_device::memory::delete_memory_op()(ctx, y); + base_device::memory::delete_memory_op()(ctx, dot_device); + return result; +} #endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp index dc520deaa1..5030c2bfda 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp @@ -180,7 +180,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) { auto nroot_gauss = std::bind(&Sto_Func::nroot_gauss, &this->stofunc, std::placeholders::_1); che.calcoef_real(nroot_gauss); - tmpsto = vTMv(che.coef_real, spolyv.data(), dos_nche); + tmpsto = vTMv(che.coef_real, spolyv.data(), dos_nche); } if (PARAM.inp.nbands > 0) { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index d4d6aee52d..4983ab1a6f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -7,6 +7,8 @@ #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_parameter/parameter.h" +#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_elecstate/kernels/elecstate_op.h" template Stochastic_Iter::Stochastic_Iter() @@ -21,15 +23,26 @@ Stochastic_Iter::~Stochastic_Iter() { } +template +void Stochastic_Iter::dot(const int& n, const Real* x, const int& incx, const Real* y, const int& incy, Real& result) +{ 
+ Real* result_device = nullptr; + resmem_var_op()(this->ctx, result_device, 1); + container::kernels::blas_dot()(n, p_che->coef_real, 1, spolyv, 1, result_device); + syncmem_var_d2h_op()(cpu_ctx, this->ctx, &result, result_device, 1); + delmem_var_op()(this->ctx, result_device); +} + template void Stochastic_Iter::init(K_Vectors* pkv_in, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf, - StoChe& stoche, + StoChe& stoche, hamilt::HamiltSdftPW* p_hamilt_sto) { p_che = stoche.p_che; spolyv = stoche.spolyv; + spolyv_cpu = stoche.spolyv_cpu; nchip = stowf.nchip; targetne = PARAM.inp.nelec; this->pkv = pkv_in; @@ -51,47 +64,51 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); T *wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); - for (int ig = 0; ig < npwx * nchipk; ++ig) - { - wfgout[ig] = wfgin[ig]; - } + cpymem_complex_op()(this->ctx, this->ctx, wfgout, wfgin, npwx * nchipk); + // for (int ig = 0; ig < npwx * nchipk; ++ig) + // { + // wfgout[ig] = wfgin[ig]; + // } // orthogonal part - T* sum = new T[PARAM.inp.nbands * nchipk]; + T* sum = nullptr; + resmem_complex_op()(this->ctx, sum, PARAM.inp.nbands * nchipk); char transC = 'C'; char transN = 'N'; // sum(b - zgemm_(&transC, - &transN, - &PARAM.inp.nbands, - &nchipk, - &npw, - &ModuleBase::ONE, - &psi(ik, 0, 0), - &npwx, - wfgout, - &npwx, - &ModuleBase::ZERO, - sum, - &PARAM.inp.nbands); + hsolver::gemm_op()(ctx, + transC, + transN, + PARAM.inp.nbands, + nchipk, + npw, + &ModuleBase::ONE, + &psi(ik, 0, 0), + npwx, + wfgout, + npwx, + &ModuleBase::ZERO, + sum, + PARAM.inp.nbands); Parallel_Reduce::reduce_pool(sum, PARAM.inp.nbands * nchipk); // psi -= psi * sum - zgemm_(&transN, - &transN, - &npw, - &nchipk, - &PARAM.inp.nbands, - &ModuleBase::NEG_ONE, - &psi(ik, 0, 0), - &npwx, - sum, - &PARAM.inp.nbands, - &ModuleBase::ONE, - wfgout, - &npwx); - delete[] sum; + hsolver::gemm_op()(ctx, + transN, + transN, + npw, + nchipk, 
+ PARAM.inp.nbands, + &ModuleBase::NEG_ONE, + &psi(ik, 0, 0), + npwx, + sum, + PARAM.inp.nbands, + &ModuleBase::ONE, + wfgout, + npwx); + delmem_complex_op()(this->ctx, sum); } } @@ -186,17 +203,25 @@ void Stochastic_Iter::check_precision(const double ref, const double double error = 0; if (this->method == 1) { - error = p_che->coef_real[p_che->norder - 1] * spolyv[p_che->norder - 1]; + Real last_coef = 0; + Real last_spolyv = 0; + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, &last_coef, &p_che->coef_real[p_che->norder - 1], 1); + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, &last_spolyv, &spolyv[p_che->norder - 1], 1); + error = last_coef * last_spolyv; } else { const int norder = p_che->norder; - double last_coef = p_che->coef_real[norder - 1]; - double last_spolyv = spolyv[norder * norder - 1]; - error = last_coef - * (BlasConnector::dot(norder, p_che->coef_real, 1, spolyv + norder * (norder - 1), 1) - + BlasConnector::dot(norder, p_che->coef_real, 1, spolyv + norder - 1, norder) - - last_coef * last_spolyv); + // double last_coef = p_che->coef_real[norder - 1]; + // double last_spolyv = spolyv[norder * norder - 1]; + Real last_coef = 0; + Real last_spolyv = 0; + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, &last_coef, &p_che->coef_real[norder - 1], 1); + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, &last_spolyv, &spolyv[norder * norder - 1], 1); + Real dot1 = 0, dot2 = 0; + this->dot(norder, p_che->coef_real, 1, spolyv + norder * (norder - 1), 1, dot1); + this->dot(norder, p_che->coef_real, 1, spolyv + norder - 1, norder, dot2); + error = last_coef * (dot1 + dot2 - last_coef * last_spolyv); } #ifdef __MPI @@ -329,11 +354,11 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& { if (this->method == 1) { - ModuleBase::GlobalFunc::ZEROS(spolyv, norder); + ModuleBase::GlobalFunc::ZEROS(spolyv_cpu, norder); } else { - ModuleBase::GlobalFunc::ZEROS(spolyv, norder * norder); + setmem_var_op()(this->ctx, spolyv, 0, norder * norder); } } T* pchi; @@ 
-358,21 +383,27 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& p_che->tracepolyA(hchi_norm, pchi, npw, npwx, nchip_ik); for (int i = 0; i < norder; ++i) { - spolyv[i] += p_che->polytrace[i] * this->pkv->wk[ik]; + spolyv_cpu[i] += p_che->polytrace[i] * this->pkv->wk[ik]; + } + if(ik == this->pkv->get_nks() - 1) + { + syncmem_var_h2d_op()(this->ctx, cpu_ctx, spolyv, spolyv_cpu, norder); } } else { p_che->calpolyvec_complex(hchi_norm, pchi, stowf.chiallorder[ik].get_pointer(), npw, npwx, nchip_ik); - double* vec_all = (double*)stowf.chiallorder[ik].get_pointer(); - char trans = 'T'; - char normal = 'N'; - double one = 1; - int LDA = npwx * nchip_ik * 2; - int M = npwx * nchip_ik * 2; // Do not use kv.ngk[ik] - int N = norder; - double kweight = this->pkv->wk[ik]; - dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv, &N); + const Real* vec_all = (Real*)stowf.chiallorder[ik].get_pointer(); + const char trans = 'T'; + const char normal = 'N'; + const Real one = 1; + const int LDA = npwx * nchip_ik * 2; + const int M = npwx * nchip_ik * 2; // Do not use kv.ngk[ik] + const int N = norder; + const Real kweight = this->pkv->wk[ik]; + + hsolver::gemm_op()(this->ctx, trans, normal, N, N, M, &kweight, vec_all, LDA, vec_all, LDA, &one, spolyv, N); + // dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv, &N); } ModuleBase::timer::tick("Stochastic_Iter", "calPn"); return; @@ -391,13 +422,13 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) // Note: spolyv contains kv.wk[ik] auto nfd = std::bind(&Sto_Func::nfd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nfd); - sto_ne = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); + this->dot(norder, p_che->coef_real, 1, spolyv, 1, sto_ne); } else { auto nroot_fd = std::bind(&Sto_Func::nroot_fd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nroot_fd); - sto_ne = vTMv(p_che->coef_real, spolyv, norder); + 
sto_ne = vTMv(p_che->coef_real, spolyv, norder); } if (PARAM.inp.nbands > 0) { @@ -434,7 +465,7 @@ void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) template void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, - elecstate::ElecState* pes, + elecstate::ElecStatePW* pes, hamilt::Hamilt* pHamilt, ModulePW::PW_Basis_K* wfc_basis) { @@ -445,18 +476,18 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, const int norder = p_che->norder; //---------------cal demet----------------------- - double stodemet; + Real stodemet; if (this->method == 1) { auto nfdlnfd = std::bind(&Sto_Func::nfdlnfd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nfdlnfd); - stodemet = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); + this->dot(norder, p_che->coef_real, 1, spolyv, 1, stodemet); } else { auto nroot_fdlnfd = std::bind(&Sto_Func::n_root_fdlnfd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nroot_fdlnfd); - stodemet = -vTMv(p_che->coef_real, spolyv, norder); + stodemet = -vTMv(p_che->coef_real, spolyv, norder); } if (PARAM.inp.nbands > 0) @@ -486,7 +517,7 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, { auto nxfd = std::bind(&Sto_Func::nxfd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nxfd); - sto_eband = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); + this->dot(norder, p_che->coef_real, 1, spolyv, 1, sto_eband); } else { @@ -500,17 +531,18 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, } const int npw = this->pkv->ngk[ik]; const double kweight = this->pkv->wk[ik]; - T* hshchi = new T[nchip_ik * npwx]; + T* hshchi = nullptr; + resmem_complex_op()(this->ctx, hshchi, nchip_ik * npwx); T* tmpin = stowf.shchi->get_pointer(); T* tmpout = hshchi; p_hamilt_sto->hPsi(tmpin, tmpout, nchip_ik); for (int ichi = 0; ichi < nchip_ik; ++ichi) { - sto_eband += kweight * ModuleBase::GlobalFunc::ddot_real(npw, tmpin, tmpout, false); + sto_eband += kweight * p_che->ddot_real(tmpin, tmpout, npw); 
tmpin += npwx; tmpout += npwx; } - delete[] hshchi; + delmem_complex_op()(this->ctx, hshchi); } } #ifdef __MPI @@ -526,7 +558,8 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, double sto_ne = 0; ModuleBase::GlobalFunc::ZEROS(sto_rho, nrxx); - T* porter = new T[nrxx]; + T* porter = nullptr; + resmem_complex_op()(this->ctx, porter, nrxx); double out2; double* ksrho = nullptr; @@ -534,25 +567,38 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, { ksrho = new double[nrxx]; ModuleBase::GlobalFunc::DCOPY(pes->charge->rho[0], ksrho, nrxx); - ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx); + setmem_var_op()(this->ctx, pes->rho[0], 0, nrxx); + // ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx); } for (int ik = 0; ik < this->pkv->get_nks(); ++ik) { const int nchip_ik = nchip[ik]; + int current_spin = 0; + if (PARAM.inp.nspin == 2) + { + current_spin = this->pkv->isk[ik]; + } stowf.shchi->fix_k(ik); T* tmpout = stowf.shchi->get_pointer(); for (int ichi = 0; ichi < nchip_ik; ++ichi) { - wfc_basis->recip2real(tmpout, porter, ik); - for (int ir = 0; ir < nrxx; ++ir) - { - pes->charge->rho[0][ir] += norm(porter[ir]) * this->pkv->wk[ik]; - } + wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik); + const auto w1 = static_cast(this->pkv->wk[ik]); + elecstate::elecstate_pw_op()(this->ctx, current_spin, nrxx, w1, pes->rho, porter); + // for (int ir = 0; ir < nrxx; ++ir) + // { + // pes->charge->rho[0][ir] += norm(porter[ir]) * this->pkv->wk[ik]; + // } tmpout += npwx; } } - delete[] porter; + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { + for (int ii = 0; ii < PARAM.inp.nspin; ii++) { + castmem_var_d2h_op()(this->cpu_ctx, this->ctx, pes->charge->rho[ii], pes->rho[ii], nrxx); + } + } + delmem_complex_op()(this->ctx, porter); #ifdef __MPI // temporary, rho_mpi should be rewrite as a tool function! 
Now it only treats pes->charge->rho pes->charge->rho_mpi(); @@ -630,20 +676,19 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WFmethod == 2) { - char transa = 'N'; - T one = 1; - int inc = 1; - T zero = 0; - int LDA = npwx * nchip[ik]; - int M = npwx * nchip[ik]; - int N = p_che->norder; - T* coef_real = new T[p_che->norder]; - for (int i = 0; i < p_che->norder; ++i) - { - coef_real[i] = p_che->coef_real[i]; - } - zgemv_(&transa, &M, &N, &one, stowf.chiallorder[ik].get_pointer(), &LDA, coef_real, &inc, &zero, out, &inc); - delete[] coef_real; + const char transa = 'N'; + const T one = 1; + const int inc = 1; + const T zero = 0; + const int LDA = npwx * nchip[ik]; + const int M = npwx * nchip[ik]; + const int N = p_che->norder; + T* coef_real = nullptr; + resmem_complex_op()(this->ctx, coef_real, N); + castmem_d2z_op()(this->ctx, this->ctx, coef_real, p_che->coef_real, p_che->norder); + gemv_op()(this->ctx, transa, M, N, &one, stowf.chiallorder[ik].get_pointer(), LDA, coef_real, inc, &zero, out, inc); + // zgemv_(&transa, &M, &N, &one, stowf.chiallorder[ik].get_pointer(), &LDA, coef_real, &inc, &zero, out, &inc); + delmem_complex_op()(this->ctx, coef_real); } else { @@ -661,3 +706,6 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class Stochastic_Iter, base_device::DEVICE_GPU>; +#endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h index 2947f98e7f..5cd0c3ec45 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h @@ -1,7 +1,7 @@ #ifndef STO_ITER_H #define STO_ITER_H #include "module_base/math_chebyshev.h" -#include "module_elecstate/elecstate.h" +#include "module_elecstate/elecstate_pw.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h" #include "module_psi/psi.h" @@ -20,7 
+20,8 @@ template class Stochastic_Iter { - + private: + using Real = typename GetTypeReal::type; public: // constructor and deconstructor Stochastic_Iter(); @@ -40,11 +41,11 @@ class Stochastic_Iter void init(K_Vectors* pkv_in, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf, - StoChe& stoche, + StoChe& stoche, hamilt::HamiltSdftPW* p_hamilt_sto); void sum_stoband(Stochastic_WF& stowf, - elecstate::ElecState* pes, + elecstate::ElecStatePW* pes, hamilt::Hamilt* pHamilt, ModulePW::PW_Basis_K* wfc_basis); @@ -58,7 +59,7 @@ class Stochastic_Iter void check_precision(const double ref, const double thr, const std::string info); - ModuleBase::Chebyshev* p_che = nullptr; + ModuleBase::Chebyshev* p_che = nullptr; Sto_Func stofunc; hamilt::HamiltSdftPW* p_hamilt_sto = nullptr; @@ -66,7 +67,8 @@ class Stochastic_Iter double mu0; // chemical potential; unit in Ry bool change; double targetne; - double* spolyv = nullptr; + Real* spolyv = nullptr; //[Device] coefficients of Chebyshev expansion + Real* spolyv_cpu = nullptr; //[CPU] coefficients of Chebyshev expansion public: int* nchip = nullptr; @@ -85,6 +87,28 @@ class Stochastic_Iter private: K_Vectors* pkv; + /** + * @brief return cpu dot result + * @param x [Device] + * @param y [Device] + * @param result [CPU] dot result + */ + void dot(const int& n, const Real* x, const int& incx, const Real* y, const int& incy, Real& result); + private: + const Device* ctx = {}; + const base_device::DEVICE_CPU* cpu_ctx = {}; + using ct_Device = typename container::PsiToContainer::type; + using setmem_var_op = base_device::memory::set_memory_op; + using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op; + using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op; + using cpymem_complex_op = base_device::memory::synchronize_memory_op; + using resmem_var_op = base_device::memory::resize_memory_op; + using delmem_var_op = base_device::memory::delete_memory_op; + using resmem_complex_op = 
base_device::memory::resize_memory_op; + using delmem_complex_op = base_device::memory::delete_memory_op; + using castmem_d2z_op = base_device::memory::cast_memory_op; + using castmem_var_d2h_op = base_device::memory::cast_memory_op; + using gemv_op = hsolver::gemv_op; }; #endif // Eelectrons_Iter diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 1f596b06db..7ad8fdcc36 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -383,3 +383,6 @@ void Stochastic_WF::sync_chi0() } template class Stochastic_WF, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class Stochastic_WF, base_device::DEVICE_GPU>; +#endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index e6f954f9af..a423810544 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -1,7 +1,6 @@ #ifndef STOCHASTIC_WF_H #define STOCHASTIC_WF_H -#include "module_base/module_container/ATen/tensor.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_cell/klist.h" #include "module_psi/psi.h" diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index d71e23d29b..f802162f0f 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -3,6 +3,7 @@ #include "module_base/global_function.h" #include "module_base/timer.h" #include "module_base/tool_title.h" +#include "module_base/parallel_device.h" #include "module_elecstate/module_charge/symmetry_rho.h" #include @@ -12,6 +13,7 @@ namespace hsolver template void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, + psi::Psi& psi_cpu, elecstate::ElecState* pes, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf, @@ -51,10 +53,10 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* 
pHamilt, } #ifdef __MPI - if (nbands > 0) + if (nbands > 0 && PARAM.inp.bndpar > 1) { - MPI_Bcast(&psi(ik, 0, 0), npwx * nbands, MPI_DOUBLE_COMPLEX, 0, PARAPW_WORLD); - MPI_Bcast(&(pes->ekb(ik, 0)), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); + Parallel_Common::bcast_complex(this->ctx, &psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); + MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); } #endif stoiter.orthog(ik, psi, stowf); @@ -85,9 +87,11 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, } //(5) calculate new charge density // calculate KS rho. + elecstate::ElecStatePW* pes_pw = static_cast*>(pes); + pes_pw->init_rho_data(); if (nbands > 0) { - pes->psiToRho(psi); + pes_pw->psiToRho(psi); #ifdef __MPI MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD); #endif @@ -96,11 +100,11 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, { for (int is = 0; is < this->nspin; is++) { - ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], pes->charge->nrxx); + setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx); } } // calculate stochastic rho - stoiter.sum_stoband(stowf, pes, pHamilt, wfc_basis); + stoiter.sum_stoband(stowf, pes_pw, pHamilt, wfc_basis); // will do rho symmetry and energy calculation in esolver ModuleBase::timer::tick("HSolverPW_SDFT", "solve"); @@ -109,4 +113,8 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, // template class HSolverPW_SDFT, base_device::DEVICE_CPU>; template class HSolverPW_SDFT, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +// template class HSolverPW_SDFT, base_device::DEVICE_GPU>; +template class HSolverPW_SDFT, base_device::DEVICE_GPU>; +#endif } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index 6fc0a39fec..c1d39a401d 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -8,12 +8,14 @@ namespace 
hsolver template class HSolverPW_SDFT : public HSolverPW { + protected: + using Real = typename GetTypeReal::type; public: HSolverPW_SDFT(K_Vectors* pkv, ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, Stochastic_WF& stowf, - StoChe& stoche, + StoChe& stoche, hamilt::HamiltSdftPW* p_hamilt_sto, const std::string calculation_type_in, const std::string basis_type_in, @@ -45,6 +47,7 @@ class HSolverPW_SDFT : public HSolverPW void solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, + psi::Psi& psi_cpu, elecstate::ElecState* pes, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf, @@ -53,6 +56,13 @@ class HSolverPW_SDFT : public HSolverPW const bool skip_charge); Stochastic_Iter stoiter; + protected: + using setmem_complex_op = base_device::memory::set_memory_op; + using setmem_var_op = base_device::memory::set_memory_op; + using syncmem_h2d_op = base_device::memory::synchronize_memory_op; + using syncmem_d2h_op = base_device::memory::synchronize_memory_op; + using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op; + using syncmem_var_d2h_op = base_device::memory::synchronize_memory_op; }; } // namespace hsolver #endif \ No newline at end of file diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 02deb41696..c3784949ab 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -359,7 +359,9 @@ struct matrixSetToAnother template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; +template struct gemv_op; template struct gemm_op, base_device::DEVICE_CPU>; +template struct gemm_op; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; @@ -373,7 +375,9 @@ template struct line_minimize_with_block_op, base_device::DE template struct 
scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; +template struct gemv_op; template struct gemm_op, base_device::DEVICE_CPU>; +template struct gemm_op; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; @@ -386,8 +390,6 @@ template struct line_minimize_with_block_op, base_device::D #ifdef __LCAO template struct axpy_op; -template struct gemv_op; -template struct gemm_op; template struct dot_real_op; template struct vector_mul_vector_op; template struct vector_div_constant_op; diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 687bc38bc7..72f690d804 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -12,6 +12,7 @@ #include "module_base/global_variable.h" #include "module_hsolver/hsolver_pw.h" #include "module_hsolver/hsolver_pw_sdft.h" +#include "module_elecstate/elecstate_pw.h" #undef private #undef protected @@ -20,16 +21,49 @@ template Sto_Func::Sto_Func() { } - template class Sto_Func; -template -StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) + +template <> +elecstate::ElecStatePW, base_device::DEVICE_CPU>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in, + Charge* chg_in, + K_Vectors* pkv_in, + UnitCell* ucell_in, + pseudopot_cell_vnl* ppcell_in, + ModulePW::PW_Basis* rhodpw_in, + ModulePW::PW_Basis* rhopw_in, + ModulePW::PW_Basis_Big* bigpw_in) + : basis(wfc_basis_in) +{ +} + +template<> +elecstate::ElecStatePW, base_device::DEVICE_CPU>::~ElecStatePW() +{ +} + +template<> +void elecstate::ElecStatePW, base_device::DEVICE_CPU>::init_rho_data() +{ +} + +template<> +void elecstate::ElecStatePW, base_device::DEVICE_CPU>::psiToRho(const psi::Psi, base_device::DEVICE_CPU>& psi) +{ +} + 
+template<> +void elecstate::ElecStatePW, base_device::DEVICE_CPU>::cal_tau(const psi::Psi, base_device::DEVICE_CPU>& psi) +{ +} + +template +StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) { this->nche = nche; } -template -StoChe::~StoChe() +template +StoChe::~StoChe() { } @@ -51,7 +85,7 @@ template void Stochastic_Iter::init(K_Vectors* pkv_in, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf, - StoChe& stoche, + StoChe& stoche, hamilt::HamiltSdftPW* p_hamilt_sto) { this->nchip = stowf.nchip; @@ -108,7 +142,7 @@ void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) template void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, - elecstate::ElecState* pes, + elecstate::ElecStatePW* pes, hamilt::Hamilt* pHamilt, ModulePW::PW_Basis_K* wfc_basis) { @@ -136,7 +170,7 @@ Charge::~Charge(){}; class TestHSolverPW_SDFT : public ::testing::Test { public: - TestHSolverPW_SDFT() : stoche(8, 1, 0, 0) + TestHSolverPW_SDFT() : stoche(8, 1, 0, 0), elecstate_test(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr) { } ModulePW::PW_Basis_K pwbk; @@ -170,7 +204,7 @@ class TestHSolverPW_SDFT : public ::testing::Test psi::Psi> psi_test_cd; psi::Psi> psi_test_no; - elecstate::ElecState elecstate_test; + elecstate::ElecStatePW> elecstate_test; std::string method_test = "cg"; @@ -193,7 +227,7 @@ TEST_F(TestHSolverPW_SDFT, solve) int istep = 0; int iter = 0; - this->hs_d.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false); + this->hs_d.solve(&hamilt_test_d, psi_test_cd, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false); EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0); EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0); @@ -234,10 +268,12 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) elecstate_test.charge->rho = new double*[1]; elecstate_test.charge->rho[0] = new double[10]; elecstate_test.charge->nrxx = 
10; + elecstate_test.rho = new double*[1]; + elecstate_test.rho[0] = new double[10]; int istep = 0; int iter = 0; - this->hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false); + this->hs_d.solve(&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false); EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); EXPECT_EQ(stowf.nbands_diag, 2); EXPECT_EQ(stowf.nbands_total, 1); @@ -251,7 +287,7 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) std::cout<<__FILE__<<__LINE__<<" "<hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true); + this->hs_d.solve(&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true); EXPECT_EQ(stowf.nbands_diag, 4); EXPECT_EQ(stowf.nbands_total, 1); EXPECT_EQ(stowf.nchi, 4);