From d54923148445a0c9f2fe3548e028a8bc0a79e697 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Mon, 25 Nov 2024 15:35:20 +0800 Subject: [PATCH] Refactor: split sum_stoband to sum_stoband and cal_storho --- .../hamilt_stodft/sto_iter.cpp | 124 ++++++++++++------ .../module_hamilt_pw/hamilt_stodft/sto_iter.h | 53 ++++++++ source/module_hsolver/hsolver_pw_sdft.cpp | 9 +- .../module_hsolver/test/test_hsolver_sdft.cpp | 7 + 4 files changed, 145 insertions(+), 48 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 0f450bdc31..b9c0fe95f6 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -479,8 +479,7 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, { ModuleBase::TITLE("Stochastic_Iter", "sum_stoband"); ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband"); - int nrxx = wfc_basis->nrxx; - int npwx = wfc_basis->npwk_max; + const int npwx = wfc_basis->npwk_max; const int norder = p_che->norder; //---------------cal demet----------------------- @@ -557,33 +556,53 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif pes->f_en.eband += sto_eband; - //---------------------cal rho------------------------- - double* sto_rho = new double[nrxx]; + ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband"); +} - double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz; - double tmprho, tmpne; - T outtem; - double sto_ne = 0; - ModuleBase::GlobalFunc::ZEROS(sto_rho, nrxx); +template +void Stochastic_Iter::cal_storho(Stochastic_WF& stowf, + elecstate::ElecStatePW* pes, + ModulePW::PW_Basis_K* wfc_basis) +{ + ModuleBase::TITLE("Stochastic_Iter", "cal_storho"); + ModuleBase::timer::tick("Stochastic_Iter", "cal_storho"); + //---------------------cal rho------------------------- + const int nrxx = wfc_basis->nrxx; + const int npwx = wfc_basis->npwk_max; + const int nspin = PARAM.inp.nspin; T* porter = nullptr; resmem_complex_op()(this->ctx, porter, nrxx); - double out2; - double* ksrho = nullptr; - if (PARAM.inp.nbands > 0 && GlobalV::MY_STOGROUP == 0) + std::vector sto_rho(nspin); + for(int is = 0; is < nspin; ++is) + { + sto_rho[is] = pes->charge->rho[is]; + } + std::vector _tmprho; + if (PARAM.inp.nbands > 0) { - ksrho = new double[nrxx]; - ModuleBase::GlobalFunc::DCOPY(pes->charge->rho[0], ksrho, nrxx); - setmem_var_op()(this->ctx, pes->rho[0], 0, nrxx); - // ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx); + // If there are KS orbitals, we need to allocate another memory for sto_rho + _tmprho.resize(nrxx * nspin); + for(int is = 0; is < nspin; ++is) + { + sto_rho[is] = _tmprho.data() + is * nrxx; + } } + // pes->rho is a device memory, and when using cpu and double, we donot need to allocate memory for pes->rho + if (PARAM.inp.device != "gpu" && PARAM.inp.precision != "single") { + pes->rho = reinterpret_cast(sto_rho.data()); + } + for (int is = 0; is < nspin; is++) + { + setmem_var_op()(this->ctx, pes->rho[is], 0, nrxx); + } for (int ik = 0; ik < this->pkv->get_nks(); ++ik) { const int nchip_ik = nchip[ik]; int current_spin = 0; - if (PARAM.inp.nspin == 2) + if (nspin == 2) { current_spin = this->pkv->isk[ik]; } @@ -602,27 +621,50 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, } } if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { - for (int ii = 0; ii < PARAM.inp.nspin; ii++) { - castmem_var_d2h_op()(this->cpu_ctx, this->ctx, pes->charge->rho[ii], pes->rho[ii], nrxx); + for(int is = 0; is < nspin; ++is) + { + castmem_var_d2h_op()(this->cpu_ctx, this->ctx, sto_rho[is], pes->rho[is], nrxx); } } + else + { + // We need to set pes->rho back to the original value + pes->rho = reinterpret_cast(pes->charge->rho); + } + delmem_complex_op()(this->ctx, porter); #ifdef __MPI - // temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho - pes->charge->rho_mpi(); + if(GlobalV::KPAR > 1) + { + for (int is = 0; is < nspin; ++is) + { + pes->charge->reduce_diff_pools(sto_rho[is]); + } + } #endif - for (int ir = 0; ir < nrxx; ++ir) + + double sto_ne = 0; + for(int is = 0; is < nspin; ++is) { - tmprho = pes->charge->rho[0][ir] / GlobalC::ucell.omega; - sto_rho[ir] = tmprho; - sto_ne += tmprho; +#ifdef _OPENMP +#pragma omp parallel for reduction(+ : sto_ne) +#endif + for (int ir = 0; ir < nrxx; ++ir) + { + sto_rho[is][ir] /= GlobalC::ucell.omega; + sto_ne += sto_rho[is][ir]; + } } - sto_ne *= dr3; + + sto_ne *= GlobalC::ucell.omega / wfc_basis->nxyz; #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); - MPI_Allreduce(MPI_IN_PLACE, sto_rho, nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); + for(int is = 0; is < nspin; ++is) + { + MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); + } #endif double factor = targetne / (KS_ne + sto_ne); if (std::abs(factor - 1) > 1e-10) @@ -635,32 +677,32 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, factor = 1; } - if (GlobalV::MY_STOGROUP == 0) + for (int is = 0; is < 1; ++is) { if (PARAM.inp.nbands > 0) { - ModuleBase::GlobalFunc::DCOPY(ksrho, pes->charge->rho[0], nrxx); +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int ir = 0; ir < nrxx; ++ir) + { + pes->charge->rho[is][ir] += sto_rho[is][ir]; + pes->charge->rho[is][ir] *= factor; + } } else { - ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx); - } - } - - if (GlobalV::MY_STOGROUP == 0) - { - for (int is = 0; is < 1; ++is) - { +#ifdef _OPENMP +#pragma omp parallel for +#endif for (int ir = 0; ir < nrxx; ++ir) { - pes->charge->rho[is][ir] += sto_rho[ir]; pes->charge->rho[is][ir] *= factor; } } } - delete[] sto_rho; - delete[] ksrho; - ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband"); + + ModuleBase::timer::tick("Stochastic_Iter", "cal_storho"); return; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h index 5cd0c3ec45..75bcda3fe6 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h @@ -44,19 +44,72 @@ class Stochastic_Iter StoChe& stoche, hamilt::HamiltSdftPW* p_hamilt_sto); + /** + * @brief sum demet and eband energies for each k point and each band + * + * @param stowf stochastic wave function + * @param pes elecstate + * @param pHamilt hamiltonian + * @param wfc_basis wfc pw basis + */ void sum_stoband(Stochastic_WF& stowf, elecstate::ElecStatePW* pes, hamilt::Hamilt* pHamilt, ModulePW::PW_Basis_K* wfc_basis); + /** + * @brief calculate the density + * + * @param stowf stochastic wave function + * @param pes elecstate + * @param wfc_basis wfc pw basis + */ + void cal_storho(Stochastic_WF& stowf, + elecstate::ElecStatePW* pes, + ModulePW::PW_Basis_K* wfc_basis); + + /** + * @brief calculate total number of electrons + * + * @param pes elecstate + * @return double + */ double calne(elecstate::ElecState* pes); + /** + * @brief solve ne(mu) = ne_target and get chemical potential mu + * + * @param iter scf iteration index + * @param pes elecstate + */ void itermu(const int iter, elecstate::ElecState* pes); + /** + * @brief orthogonalize stochastic wave functions with KS wave functions + * + * @param ik k point index + * @param psi KS wave functions + * @param stowf stochastic wave functions + */ void orthog(const int& ik, psi::Psi& psi, Stochastic_WF& stowf); + /** + * @brief check emax and emin + * + * @param ik k point index + * @param istep ion step index + * @param iter scf iteration index + * @param stowf stochastic wave functions + */ void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf); + /** + * @brief check precision of Chebyshev expansion + * + * @param ref reference value + * @param thr threshold + * @param info information + */ void check_precision(const double ref, const double thr, const std::string info); ModuleBase::Chebyshev* p_che = nullptr; diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index f49cad787c..e66db52392 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -128,15 +128,10 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD); #endif } - else - { - for (int is = 0; is < this->nspin; is++) - { - setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx); - } - } + // calculate stochastic rho stoiter.sum_stoband(stowf, pes_pw, pHamilt, wfc_basis); + stoiter.cal_storho(stowf, pes_pw, wfc_basis); // will do rho symmetry and energy calculation in esolver ModuleBase::timer::tick("HSolverPW_SDFT", "solve"); diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 7b234e953e..05bda81301 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -151,6 +151,13 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, return; } +template +void Stochastic_Iter::cal_storho(Stochastic_WF& stowf, + elecstate::ElecStatePW* pes, + ModulePW::PW_Basis_K* wfc_basis) +{ +} + Charge::Charge(){}; Charge::~Charge(){};