diff --git a/source/module_elecstate/elecstate_pw_sdft.cpp b/source/module_elecstate/elecstate_pw_sdft.cpp index bb840b0cfc..c36513c838 100644 --- a/source/module_elecstate/elecstate_pw_sdft.cpp +++ b/source/module_elecstate/elecstate_pw_sdft.cpp @@ -1,41 +1,47 @@ #include "./elecstate_pw_sdft.h" + +#include "module_base/global_function.h" #include "module_base/global_variable.h" -#include "module_parameter/parameter.h" #include "module_base/timer.h" -#include "module_base/global_function.h" #include "module_hamilt_general/module_xc/xc_functional.h" +#include "module_parameter/parameter.h" namespace elecstate { - void ElecStatePW_SDFT::psiToRho(const psi::Psi>& psi) + +template +void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) +{ + ModuleBase::TITLE(this->classname, "psiToRho"); + ModuleBase::timer::tick(this->classname, "psiToRho"); + for (int is = 0; is < PARAM.inp.nspin; is++) { - ModuleBase::TITLE(this->classname, "psiToRho"); - ModuleBase::timer::tick(this->classname, "psiToRho"); - for(int is=0; is < PARAM.inp.nspin; is++) - { - ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) - { - ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); - } - } - - if(GlobalV::MY_STOGROUP == 0) - { - this->calEBand(); + ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); + if (XC_Functional::get_func_type() == 3) + { + ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); + } + } + + if (GlobalV::MY_STOGROUP == 0) + { + this->calEBand(); - for(int is=0; ischarge->rho[is], this->charge->nrxx); - } + for (int is = 0; is < PARAM.inp.nspin; is++) + { + ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); + } - for (int ik = 0; ik < psi.get_nk(); ++ik) - { - psi.fix_k(ik); - this->updateRhoK(psi); - } - this->parallelK(); + for (int ik = 0; ik < psi.get_nk(); ++ik) + { + psi.fix_k(ik); + this->updateRhoK(psi); } - ModuleBase::timer::tick(this->classname, "psiToRho"); - return; + this->parallelK(); } -} \ No newline at end of file + ModuleBase::timer::tick(this->classname, "psiToRho"); + return; +} + +// template class ElecStatePW_SDFT, base_device::DEVICE_CPU>; +template class ElecStatePW_SDFT, base_device::DEVICE_CPU>; +} // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw_sdft.h b/source/module_elecstate/elecstate_pw_sdft.h index 2dd312ac08..eaf9215390 100644 --- a/source/module_elecstate/elecstate_pw_sdft.h +++ b/source/module_elecstate/elecstate_pw_sdft.h @@ -3,22 +3,24 @@ #include "elecstate_pw.h" namespace elecstate { - class ElecStatePW_SDFT : public ElecStatePW> +template +class ElecStatePW_SDFT : public ElecStatePW +{ + public: + ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in, + Charge* chg_in, + K_Vectors* pkv_in, + UnitCell* ucell_in, + pseudopot_cell_vnl* ppcell_in, + ModulePW::PW_Basis* rhodpw_in, + ModulePW::PW_Basis* rhopw_in, + ModulePW::PW_Basis_Big* bigpw_in) + : ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in) { - public: - ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in, - Charge* chg_in, - K_Vectors* pkv_in, - UnitCell* ucell_in, - pseudopot_cell_vnl* ppcell_in, - ModulePW::PW_Basis* rhodpw_in, - ModulePW::PW_Basis* rhopw_in, - ModulePW::PW_Basis_Big* bigpw_in) - : ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in) - { - this->classname = "ElecStatePW_SDFT"; - } - virtual void psiToRho(const psi::Psi>& psi) override; - }; -} + this->classname = "ElecStatePW_SDFT"; + } + virtual void psiToRho(const psi::Psi& psi) override; +}; +} // namespace elecstate #endif \ No newline at end of file diff --git a/source/module_esolver/esolver.cpp b/source/module_esolver/esolver.cpp index b0247fd63e..17671ada74 100644 --- a/source/module_esolver/esolver.cpp +++ b/source/module_esolver/esolver.cpp @@ -153,6 +153,17 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell) return new ESolver_KS_PW, base_device::DEVICE_CPU>(); } } + else if (esolver_type == "sdft_pw") + { + // if (PARAM.inp.precision == "single") + // { + // return new ESolver_SDFT_PW, base_device::DEVICE_CPU>(); + // } + // else + // { + return new ESolver_SDFT_PW, base_device::DEVICE_CPU>(); + // } + } #ifdef __LCAO else if (esolver_type == "ksdft_lip") { @@ -230,10 +241,6 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell) return p_esolver_lr; } #endif - else if (esolver_type == "sdft_pw") - { - return new ESolver_SDFT_PW(); - } else if(esolver_type == "ofdft") { return new ESolver_OF(); diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index f79454891f..38998041e9 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -441,6 +441,9 @@ void ESolver_KS_PW::update_pot(const int istep, const int iter) } this->pelec->pot->update_from_charge(this->pelec->charge, &GlobalC::ucell); this->pelec->f_en.descf = this->pelec->cal_delta_escf(); +#ifdef __MPI + MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, PARAPW_WORLD); +#endif } else { diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 5b79d16818..d7e5a5b317 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -1,6 +1,5 @@ #include "esolver_sdft_pw.h" -#include "module_parameter/parameter.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_elecstate/elecstate_pw_sdft.h" @@ -11,6 +10,7 @@ #include "module_io/cube_io.h" #include "module_io/output_log.h" #include "module_io/write_istate_info.h" +#include "module_parameter/parameter.h" #include #include @@ -29,35 +29,38 @@ namespace ModuleESolver { -ESolver_SDFT_PW::ESolver_SDFT_PW() +template +ESolver_SDFT_PW::ESolver_SDFT_PW() : stoche(PARAM.inp.nche_sto, PARAM.inp.method_sto, PARAM.inp.emax_sto, PARAM.inp.emin_sto) { - classname = "ESolver_SDFT_PW"; - basisname = "PW"; + this->classname = "ESolver_SDFT_PW"; + this->basisname = "PW"; } -ESolver_SDFT_PW::~ESolver_SDFT_PW() +template +ESolver_SDFT_PW::~ESolver_SDFT_PW() { } -void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) +template +void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) { // 1) initialize parameters from int Input class this->nche_sto = inp.nche_sto; this->method_sto = inp.method_sto; // 2) run "before_all_runners" in ESolver_KS - ESolver_KS::before_all_runners(inp, ucell); + ESolver_KS::before_all_runners(inp, ucell); // 3) initialize the pointer for electronic states of SDFT - this->pelec = new elecstate::ElecStatePW_SDFT(pw_wfc, - &(chr), - (K_Vectors*)(&(kv)), - &ucell, - &(GlobalC::ppcell), - this->pw_rhod, - this->pw_rho, - pw_big); + this->pelec = new elecstate::ElecStatePW_SDFT(this->pw_wfc, + &(this->chr), + &(this->kv), + &ucell, + &(GlobalC::ppcell), + this->pw_rhod, + this->pw_rho, + this->pw_big); // 4) inititlize the charge density. this->pelec->charge->allocate(PARAM.inp.nspin); @@ -66,23 +69,22 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) // 5) initialize the potential. if (this->pelec->pot == nullptr) { - this->pelec->pot = new elecstate::Potential(pw_rhod, - pw_rho, + this->pelec->pot = new elecstate::Potential(this->pw_rhod, + this->pw_rho, &ucell, &(GlobalC::ppcell.vloc), - &(sf), + &(this->sf), &(this->pelec->f_en.etxc), &(this->pelec->f_en.vtxc)); - GlobalTemp::veff = &(this->pelec->pot->get_effective_v()); } // 6) prepare some parameters for electronic wave functions initilization - this->p_wf_init = new psi::WFInit>(PARAM.inp.init_wfc, - PARAM.inp.ks_solver, - PARAM.inp.basis_type, - PARAM.inp.psi_initializer, - &this->wf, - this->pw_wfc); + this->p_wf_init = new psi::WFInit(PARAM.inp.init_wfc, + PARAM.inp.ks_solver, + PARAM.inp.basis_type, + PARAM.inp.psi_initializer, + &this->wf, + this->pw_wfc); // 7) set occupatio, redundant? if (PARAM.inp.ocp) { @@ -93,74 +95,79 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) this->Init_GlobalC(inp, ucell, GlobalC::ppcell); // temporary // 9) initialize the stochastic wave functions - stowf.init(&kv, pw_wfc->npwk_max); + this->stowf.init(&this->kv, this->pw_wfc->npwk_max); if (inp.nbands_sto != 0) { if (inp.initsto_ecut < inp.ecutwfc) { - Init_Sto_Orbitals(this->stowf, inp.seed_sto); + this->stowf.init_sto_orbitals(inp.seed_sto); } else { - Init_Sto_Orbitals_Ecut(this->stowf, inp.seed_sto, kv, *pw_wfc, inp.initsto_ecut); + this->stowf.init_sto_orbitals_Ecut(inp.seed_sto, this->kv, *this->pw_wfc, inp.initsto_ecut); } } else { - Init_Com_Orbitals(this->stowf); + this->stowf.init_com_orbitals(); } if (this->method_sto == 2) { - stowf.allocate_chiallorder(this->nche_sto); + this->stowf.allocate_chiallorder(this->nche_sto); } + this->stowf.sync_chi0(); + // 10) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}> size_t size = stowf.chi0->size(); - - this->stowf.shchi = new psi::Psi>(kv.get_nks(), stowf.nchip_max, wf.npwx, kv.ngk.data()); - - ModuleBase::Memory::record("SDFT::shchi", size * sizeof(std::complex)); + this->stowf.shchi + = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data()); + ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T)); if (PARAM.inp.nbands > 0) { this->stowf.chiortho - = new psi::Psi>(kv.get_nks(), stowf.nchip_max, wf.npwx, kv.ngk.data()); - ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(std::complex)); + = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data()); + ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T)); } return; } -void ESolver_SDFT_PW::before_scf(const int istep) +template +void ESolver_SDFT_PW::before_scf(const int istep) { - ESolver_KS_PW::before_scf(istep); + ESolver_KS_PW::before_scf(istep); delete reinterpret_cast*>(this->p_hamilt); - this->p_hamilt = new hamilt::HamiltSdftPW>(this->pelec->pot, - this->pw_wfc, - &this->kv, - PARAM.globalv.npol, - &this->stoche.emin_sto, - &this->stoche.emax_sto); - this->p_hamilt_sto = static_cast>*>(this->p_hamilt); + this->p_hamilt = new hamilt::HamiltSdftPW(this->pelec->pot, + this->pw_wfc, + &this->kv, + PARAM.globalv.npol, + &this->stoche.emin_sto, + &this->stoche.emax_sto); + this->p_hamilt_sto = static_cast*>(this->p_hamilt); if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0) { - Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto); + this->stowf.update_sto_orbitals(PARAM.inp.seed_sto); } } -void ESolver_SDFT_PW::iter_finish(int& iter) +template +void ESolver_SDFT_PW::iter_finish(int& iter) { // call iter_finish() of ESolver_KS - ESolver_KS>::iter_finish(iter); + ESolver_KS::iter_finish(iter); } -void ESolver_SDFT_PW::after_scf(const int istep) +template +void ESolver_SDFT_PW::after_scf(const int istep) { // 1) call after_scf() of ESolver_KS_PW - ESolver_KS_PW>::after_scf(istep); + ESolver_KS_PW::after_scf(istep); } -void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) +template +void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) { // reset energy this->pelec->f_en.eband = 0.0; @@ -169,45 +176,37 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) // be careful that istep start from 0 and iter start from 1 if (istep == 0 && iter == 1) { - hsolver::DiagoIterAssist>::need_subspace = false; + hsolver::DiagoIterAssist::need_subspace = false; } else { - hsolver::DiagoIterAssist>::need_subspace = true; + hsolver::DiagoIterAssist::need_subspace = true; } - hsolver::DiagoIterAssist>::PW_DIAG_THR = ethr; + hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - hsolver::DiagoIterAssist>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; + hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; // hsolver only exists in this function - hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj( - &this->kv, - this->pw_wfc, - &this->wf, - this->stowf, - this->stoche, - this->p_hamilt_sto, - PARAM.inp.calculation, - PARAM.inp.basis_type, - PARAM.inp.ks_solver, - PARAM.inp.use_paw, - PARAM.globalv.use_uspp, - PARAM.inp.nspin, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, - hsolver::DiagoIterAssist>::need_subspace, - this->init_psi); - - hsolver_pw_sdft_obj.solve(this->p_hamilt, - this->psi[0], - this->pelec, - pw_wfc, - this->stowf, - istep, - iter, - false); + hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, + this->pw_wfc, + &this->wf, + this->stowf, + this->stoche, + this->p_hamilt_sto, + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.ks_solver, + PARAM.inp.use_paw, + PARAM.globalv.use_uspp, + PARAM.inp.nspin, + hsolver::DiagoIterAssist::SCF_ITER, + hsolver::DiagoIterAssist::PW_DIAG_NMAX, + hsolver::DiagoIterAssist::PW_DIAG_THR, + hsolver::DiagoIterAssist::need_subspace, + this->init_psi); + + hsolver_pw_sdft_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false); this->init_psi = true; // set_diagethr need it @@ -218,7 +217,7 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) Symmetry_rho srho; for (int is = 0; is < PARAM.inp.nspin; is++) { - srho.begin(is, *(this->pelec->charge), pw_rho, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::ucell.symm); } this->pelec->f_en.deband = this->pelec->cal_delta_eband(); } @@ -231,45 +230,70 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) } #endif } +#ifdef __MPI + MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, PARAPW_WORLD); +#endif } -double ESolver_SDFT_PW::cal_energy() +template +double ESolver_SDFT_PW::cal_energy() { return this->pelec->f_en.etot; } -void ESolver_SDFT_PW::cal_force(ModuleBase::matrix& force) +template +void ESolver_SDFT_PW::cal_force(ModuleBase::matrix& force) { Sto_Forces ff(GlobalC::ucell.nat); - ff.cal_stoforce(force, *this->pelec, pw_rho, &GlobalC::ucell.symm, &sf, &kv, pw_wfc, this->psi, this->stowf); + ff.cal_stoforce(force, + *this->pelec, + this->pw_rho, + &GlobalC::ucell.symm, + &this->sf, + &this->kv, + this->pw_wfc, + this->psi, + this->stowf); } -void ESolver_SDFT_PW::cal_stress(ModuleBase::matrix& stress) +template +void ESolver_SDFT_PW::cal_stress(ModuleBase::matrix& stress) { Sto_Stress_PW ss; ss.cal_stress(stress, *this->pelec, - pw_rho, + this->pw_rho, &GlobalC::ucell.symm, - &sf, - &kv, - pw_wfc, + &this->sf, + &this->kv, + this->pw_wfc, this->psi, this->stowf, - pelec->charge, + this->pelec->charge, &GlobalC::ppcell, GlobalC::ucell); } -void ESolver_SDFT_PW::after_all_runners() +template +void ESolver_SDFT_PW::after_all_runners() +{ + GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl; + GlobalV::ofs_running << std::setprecision(16); + GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl; + GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl; + ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints)); +} + +template <> +void ESolver_SDFT_PW, base_device::DEVICE_CPU>::after_all_runners() { GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl; GlobalV::ofs_running << std::setprecision(16); GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl; GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl; - ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, kv, &(GlobalC::Pkpoints)); + ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints)); if (this->method_sto == 2) { @@ -277,13 +301,7 @@ void ESolver_SDFT_PW::after_all_runners() } if (PARAM.inp.out_dos) { - Sto_DOS sto_dos(this->pw_wfc, - &this->kv, - this->pelec, - this->psi, - this->p_hamilt, - this->stoche, - &stowf); + Sto_DOS sto_dos(this->pw_wfc, &this->kv, this->pelec, this->psi, this->p_hamilt, this->stoche, &stowf); sto_dos.decide_param(PARAM.inp.dos_nche, PARAM.inp.emin_sto, PARAM.inp.emax_sto, @@ -318,7 +336,8 @@ void ESolver_SDFT_PW::after_all_runners() } } -void ESolver_SDFT_PW::others(const int istep) +template +void ESolver_SDFT_PW::others(const int istep) { ModuleBase::TITLE("ESolver_SDFT_PW", "others"); @@ -328,13 +347,14 @@ void ESolver_SDFT_PW::others(const int istep) } else { - ModuleBase::WARNING_QUIT("ESolver_SDFT_PW::others", "CALCULATION type not supported"); + ModuleBase::WARNING_QUIT("ESolver_SDFT_PW::others", "CALCULATION type not supported"); } return; } -void ESolver_SDFT_PW::nscf() +template +void ESolver_SDFT_PW::nscf() { ModuleBase::TITLE("ESolver_SDFT_PW", "nscf"); ModuleBase::timer::tick("ESolver_SDFT_PW", "nscf"); @@ -356,4 +376,7 @@ void ESolver_SDFT_PW::nscf() ModuleBase::timer::tick("ESolver_SDFT_PW", "nscf"); return; } + +// template class ESolver_SDFT_PW, base_device::DEVICE_CPU>; +template class ESolver_SDFT_PW, base_device::DEVICE_CPU>; } // namespace ModuleESolver diff --git a/source/module_esolver/esolver_sdft_pw.h b/source/module_esolver/esolver_sdft_pw.h index e3cdba3ab0..8f544febe6 100644 --- a/source/module_esolver/esolver_sdft_pw.h +++ b/source/module_esolver/esolver_sdft_pw.h @@ -2,15 +2,16 @@ #define ESOLVER_SDFT_PW_H #include "esolver_ks_pw.h" +#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h" +#include "module_hamilt_pw/hamilt_stodft/sto_che.h" #include "module_hamilt_pw/hamilt_stodft/sto_iter.h" #include "module_hamilt_pw/hamilt_stodft/sto_wf.h" -#include "module_hamilt_pw/hamilt_stodft/sto_che.h" -#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h" namespace ModuleESolver { -class ESolver_SDFT_PW : public ESolver_KS_PW> +template +class ESolver_SDFT_PW : public ESolver_KS_PW { public: ESolver_SDFT_PW(); @@ -25,9 +26,9 @@ class ESolver_SDFT_PW : public ESolver_KS_PW> void cal_stress(ModuleBase::matrix& stress) override; public: - Stochastic_WF stowf; + Stochastic_WF stowf; StoChe stoche; - hamilt::HamiltSdftPW>* p_hamilt_sto = nullptr; + hamilt::HamiltSdftPW* p_hamilt_sto = nullptr; protected: virtual void before_scf(const int istep) override; @@ -50,13 +51,4 @@ class ESolver_SDFT_PW : public ESolver_KS_PW> }; } // namespace ModuleESolver - -// temporary setting: removed GlobalC but not breaking design philosophy -namespace GlobalTemp -{ - -extern const ModuleBase::matrix* veff; - -} - #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp index ee0a5fe1db..dc520deaa1 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp @@ -1,8 +1,8 @@ #include "sto_dos.h" -#include "module_parameter/parameter.h" #include "module_base/timer.h" #include "module_base/tool_title.h" +#include "module_parameter/parameter.h" #include "sto_tool.h" Sto_DOS::~Sto_DOS() { @@ -14,7 +14,7 @@ Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, psi::Psi>* p_psi_in, hamilt::Hamilt>* p_hamilt_in, StoChe& stoche, - Stochastic_WF* p_stowf_in) + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf_in) { this->p_wfcpw = p_wfcpw_in; this->p_kv = p_kv_in; @@ -38,13 +38,7 @@ void Sto_DOS::decide_param(const int& dos_nche, const double& dos_scale) { this->dos_nche = dos_nche; - check_che(this->dos_nche, - emin_sto, - emax_sto, - this->nbands_sto, - this->p_kv, - this->p_stowf, - this->p_hamilt_sto); + check_che(this->dos_nche, emin_sto, emax_sto, this->nbands_sto, this->p_kv, this->p_stowf, this->p_hamilt_sto); if (dos_setemax) { this->emax = dos_emax_ev; @@ -147,12 +141,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) } ModuleBase::GlobalFunc::ZEROS(allorderchi.data(), nchipk_new * npwx * dos_nche); std::complex* tmpchi = pchi + start_nchipk * npwx; - che.calpolyvec_complex(hchi_norm, - tmpchi, - allorderchi.data(), - npw, - npwx, - nchipk_new); + che.calpolyvec_complex(hchi_norm, tmpchi, allorderchi.data(), npw, npwx, nchipk_new); double* vec_all = (double*)allorderchi.data(); int LDA = npwx * nchipk_new * 2; int M = npwx * nchipk_new * 2; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_dos.h b/source/module_hamilt_pw/hamilt_stodft/sto_dos.h index 66d601c506..7b717df07c 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_dos.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_dos.h @@ -15,7 +15,7 @@ class Sto_DOS psi::Psi>* p_psi_in, hamilt::Hamilt>* p_hamilt_in, StoChe& stoche, - Stochastic_WF* p_stowf_in); + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf_in); ~Sto_DOS(); /** @@ -59,8 +59,10 @@ class Sto_DOS elecstate::ElecState* p_elec = nullptr; ///< pointer to the electronic state psi::Psi>* p_psi = nullptr; ///< pointer to the wavefunction hamilt::Hamilt>* p_hamilt; ///< pointer to the Hamiltonian - Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions - Sto_Func stofunc; ///< functions + + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf + = nullptr; ///< pointer to the stochastic wavefunctions + Sto_Func stofunc; ///< functions hamilt::HamiltSdftPW>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT }; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp index 2a6e37ec65..8692a586ee 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp @@ -1,11 +1,11 @@ #include "sto_elecond.h" -#include "module_parameter/parameter.h" #include "module_base/complexmatrix.h" #include "module_base/constants.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_base/vector3.h" +#include "module_parameter/parameter.h" #include "sto_tool.h" #include @@ -21,7 +21,7 @@ Sto_EleCond::Sto_EleCond(UnitCell* p_ucell_in, pseudopot_cell_vnl* p_ppcell_in, hamilt::Hamilt>* p_hamilt_in, StoChe& stoche, - Stochastic_WF* p_stowf_in) + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf_in) : EleCond(p_ucell_in, p_kv_in, p_elec_in, p_wfcpw_in, p_psi_in, p_ppcell_in) { this->p_hamilt = p_hamilt_in; @@ -386,9 +386,10 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, remain -= tmpnb; startnb += tmpnb; - if (remain == 0) { + if (remain == 0) + { break; -} + } } for (int id = 0; id < ndim; ++id) @@ -1013,8 +1014,3 @@ void Sto_EleCond::sKG(const int& smear_type, } ModuleBase::timer::tick("Sto_EleCond", "sKG"); } - -namespace GlobalTemp -{ -const ModuleBase::matrix* veff; -} diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h index 6026fa1840..77dccff792 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h @@ -17,7 +17,7 @@ class Sto_EleCond : protected EleCond pseudopot_cell_vnl* p_ppcell_in, hamilt::Hamilt>* p_hamilt_in, StoChe& stoche, - Stochastic_WF* p_stowf_in); + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf_in); ~Sto_EleCond(){}; /** * @brief Set the N order of Chebyshev expansion for conductivities @@ -59,8 +59,9 @@ class Sto_EleCond : protected EleCond int fd_nche = 0; ///< number of Chebyshev orders for Fermi-Dirac function int cond_dtbatch = 0; ///< number of time steps in a batch hamilt::Hamilt>* p_hamilt; ///< pointer to the Hamiltonian - Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions - Sto_Func stofunc; ///< functions + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf + = nullptr; ///< pointer to the stochastic wavefunctions + Sto_Func stofunc; ///< functions hamilt::HamiltSdftPW>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp index 6984bfb33e..8cb543b026 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp @@ -1,6 +1,5 @@ #include "sto_forces.h" -#include "module_parameter/parameter.h" #include "module_base/mathzone.h" #include "module_cell/module_symmetry/symmetry.h" #include "module_elecstate/elecstate.h" @@ -8,12 +7,13 @@ #include "module_elecstate/potentials/gatefield.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_io/output_log.h" +#include "module_parameter/parameter.h" // new -#include "module_hamilt_general/module_xc/xc_functional.h" #include "module_base/math_integral.h" #include "module_base/parallel_reduce.h" #include "module_base/timer.h" +#include "module_hamilt_general/module_xc/xc_functional.h" void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, const elecstate::ElecState& elec, @@ -23,11 +23,11 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, K_Vectors* pkv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf) + Stochastic_WF, base_device::DEVICE_CPU>& stowf) { - ModuleBase::timer::tick("Sto_Force","cal_force"); - ModuleBase::TITLE("Sto_Forces", "init"); - this->nat = GlobalC::ucell.nat; + ModuleBase::timer::tick("Sto_Force", "cal_force"); + ModuleBase::TITLE("Sto_Forces", "init"); + this->nat = GlobalC::ucell.nat; const ModuleBase::matrix& wg = elec.wg; const Charge* chr = elec.charge; force.create(nat, 3); @@ -43,56 +43,52 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, this->cal_force_cc(forcecc, rho_basis, chr, GlobalC::ucell); this->cal_force_scc(forcescc, rho_basis, elec.vnew, elec.vnew_exist, GlobalC::ucell); - //impose total force = 0 + // impose total force = 0 int iat = 0; - ModuleBase::matrix force_e; - if(PARAM.inp.efield_flag) - { - force_e.create( GlobalC::ucell.nat, 3); - elecstate::Efield::compute_force(GlobalC::ucell, force_e); - } + ModuleBase::matrix force_e; + if (PARAM.inp.efield_flag) + { + force_e.create(GlobalC::ucell.nat, 3); + elecstate::Efield::compute_force(GlobalC::ucell, force_e); + } ModuleBase::matrix force_gate; - if(PARAM.inp.gate_flag) + if (PARAM.inp.gate_flag) { - force_gate.create( GlobalC::ucell.nat, 3); + force_gate.create(GlobalC::ucell.nat, 3); elecstate::Gatefield::compute_force(GlobalC::ucell, force_gate); } - for (int ipol = 0; ipol < 3; ipol++) - { - double sum = 0.0; - iat = 0; - - for (int it = 0;it < GlobalC::ucell.ntype;it++) - { - for (int ia = 0;ia < GlobalC::ucell.atoms[it].na;ia++) - { - force(iat, ipol) = - forcelc(iat, ipol) - + forceion(iat, ipol) - + forcenl(iat, ipol) - + forcecc(iat, ipol) - + forcescc(iat, ipol); - - if(PARAM.inp.efield_flag) - { - force(iat,ipol) = force(iat, ipol) + force_e(iat, ipol); - } - - if(PARAM.inp.gate_flag) + for (int ipol = 0; ipol < 3; ipol++) + { + double sum = 0.0; + iat = 0; + + for (int it = 0; it < GlobalC::ucell.ntype; it++) + { + for (int ia = 0; ia < GlobalC::ucell.atoms[it].na; ia++) + { + force(iat, ipol) = forcelc(iat, ipol) + forceion(iat, ipol) + forcenl(iat, ipol) + forcecc(iat, ipol) + + forcescc(iat, ipol); + + if (PARAM.inp.efield_flag) { - force(iat,ipol) = force(iat, ipol) + force_gate(iat, ipol); + force(iat, ipol) = force(iat, ipol) + force_e(iat, ipol); } - sum += force(iat, ipol); + if (PARAM.inp.gate_flag) + { + force(iat, ipol) = force(iat, ipol) + force_gate(iat, ipol); + } - iat++; - } - } + sum += force(iat, ipol); + + iat++; + } + } - if(!(PARAM.inp.gate_flag || PARAM.inp.efield_flag)) + if (!(PARAM.inp.gate_flag || PARAM.inp.efield_flag)) { double compen = sum / GlobalC::ucell.nat; for (int iat = 0; iat < GlobalC::ucell.nat; ++iat) @@ -100,74 +96,102 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, force(iat, ipol) = force(iat, ipol) - compen; } } - } + } - if(PARAM.inp.gate_flag || PARAM.inp.efield_flag) + if (PARAM.inp.gate_flag || PARAM.inp.efield_flag) { GlobalV::ofs_running << "Atomic forces are not shifted if gate_flag or efield_flag == true!" << std::endl; } - - if(ModuleSymmetry::Symmetry::symm_flag == 1) + + if (ModuleSymmetry::Symmetry::symm_flag == 1) { double d1, d2, d3; - for(int iat=0; iatsymmetrize_vec3_nat(force.c); for (int iat = 0; iat < GlobalC::ucell.nat; iat++) { - ModuleBase::Mathzone::Direct_to_Cartesian(force(iat,0),force(iat,1),force(iat,2), - GlobalC::ucell.a1.x, GlobalC::ucell.a1.y, GlobalC::ucell.a1.z, - GlobalC::ucell.a2.x, GlobalC::ucell.a2.y, GlobalC::ucell.a2.z, - GlobalC::ucell.a3.x, GlobalC::ucell.a3.y, GlobalC::ucell.a3.z, - d1,d2,d3); - force(iat,0) = d1;force(iat,1) = d2;force(iat,2) = d3; + ModuleBase::Mathzone::Direct_to_Cartesian(force(iat, 0), + force(iat, 1), + force(iat, 2), + GlobalC::ucell.a1.x, + GlobalC::ucell.a1.y, + GlobalC::ucell.a1.z, + GlobalC::ucell.a2.x, + GlobalC::ucell.a2.y, + GlobalC::ucell.a2.z, + GlobalC::ucell.a3.x, + GlobalC::ucell.a3.y, + GlobalC::ucell.a3.z, + d1, + d2, + d3); + force(iat, 0) = d1; + force(iat, 1) = d2; + force(iat, 2) = d3; } } - GlobalV::ofs_running << setiosflags(std::ios::fixed) << std::setprecision(6) << std::endl; - if(PARAM.inp.test_force) - { + GlobalV::ofs_running << setiosflags(std::ios::fixed) << std::setprecision(6) << std::endl; + if (PARAM.inp.test_force) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "LOCAL FORCE (Ry/Bohr)", forcelc); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "NONLOCAL FORCE (Ry/Bohr)", forcenl); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "NLCC FORCE (Ry/Bohr)", forcecc); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "ION FORCE (Ry/Bohr)", forceion); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "SCC FORCE (Ry/Bohr)", forcescc); - if (PARAM.inp.efield_flag) { + if (PARAM.inp.efield_flag) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "EFIELD FORCE (Ry/Bohr)", force_e); -} - if (PARAM.inp.gate_flag) { + } + if (PARAM.inp.gate_flag) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "GATEFIELD FORCE (Ry/Bohr)", force_gate); -} + } } // output force in unit eV/Angstrom GlobalV::ofs_running << std::endl; - - if(PARAM.inp.test_force) - { + + if (PARAM.inp.test_force) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "LOCAL FORCE (eV/Angstrom)", forcelc, false); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "NONLOCAL FORCE (eV/Angstrom)", forcenl, false); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "NLCC FORCE (eV/Angstrom)", forcecc, false); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "ION FORCE (eV/Angstrom)", forceion, false); ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "SCC FORCE (eV/Angstrom)", forcescc, false); - if (PARAM.inp.efield_flag) { + if (PARAM.inp.efield_flag) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "EFIELD FORCE (eV/Angstrom)", force_e, false); -} - if (PARAM.inp.gate_flag) { + } + if (PARAM.inp.gate_flag) + { ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "GATEFIELD FORCE (eV/Angstrom)", force_gate, false); -} + } } ModuleIO::print_force(GlobalV::ofs_running, GlobalC::ucell, "TOTAL-FORCE (eV/Angstrom)", force, false); ModuleBase::timer::tick("Sto_Force", "cal_force"); @@ -179,33 +203,36 @@ void Sto_Forces::cal_sto_force_nl(ModuleBase::matrix& forcenl, K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf) + Stochastic_WF, base_device::DEVICE_CPU>& stowf) { - ModuleBase::TITLE("Sto_Forces","cal_force_nl"); - ModuleBase::timer::tick("Sto_Forces","cal_force_nl"); + ModuleBase::TITLE("Sto_Forces", "cal_force_nl"); + ModuleBase::timer::tick("Sto_Forces", "cal_force_nl"); const int nkb = GlobalC::ppcell.nkb; int* nchip = stowf.nchip; - if(nkb == 0) { return; // mohan add 2010-07-25 -} - - const int npwx = wfc_basis->npwk_max; - // vkb1: |Beta(nkb,npw)> - ModuleBase::ComplexMatrix vkb1( nkb, npwx ); - int nksbands = psi_in->get_nbands(); - if(GlobalV::MY_STOGROUP != 0) { nksbands = 0; -} - + if (nkb == 0) + { + return; // mohan add 2010-07-25 + } + + const int npwx = wfc_basis->npwk_max; + // vkb1: |Beta(nkb,npw)> + ModuleBase::ComplexMatrix vkb1(nkb, npwx); + int nksbands = psi_in->get_nbands(); + if (GlobalV::MY_STOGROUP != 0) + { + nksbands = 0; + } - for (int ik = 0;ik < wfc_basis->nks;ik++) + for (int ik = 0; ik < wfc_basis->nks; ik++) { - const int nstobands = nchip[ik]; - const int nbandstot = nstobands + nksbands; - const int npw = wfc_basis->npwk[ik]; + const int nstobands = nchip[ik]; + const int nbandstot = nstobands + nksbands; + const int npw = wfc_basis->npwk[ik]; - // dbecp: conj( -iG * ) - ModuleBase::ComplexArray dbecp( 3, nbandstot, nkb); - ModuleBase::ComplexMatrix becp( nbandstot, nkb); + // dbecp: conj( -iG * ) + ModuleBase::ComplexArray dbecp(3, nbandstot, nkb); + ModuleBase::ComplexMatrix becp(nbandstot, nkb); const int current_spin = p_kv->isk[ik]; // generate vkb @@ -216,141 +243,181 @@ void Sto_Forces::cal_sto_force_nl(ModuleBase::matrix& forcenl, // get becp according to wave functions and vkb // important here ! becp must set zero!! - // vkb: Beta(nkb,npw) - // becp(nkb,nbnd): + // vkb: Beta(nkb,npw) + // becp(nkb,nbnd): becp.zero_out(); - char transa = 'C'; + char transa = 'C'; char transb = 'N'; - psi_in->fix_k(ik); - stowf.shchi->fix_k(ik); - //KS orbitals - int npmks = PARAM.globalv.npol * nksbands; - zgemm_(&transa,&transb,&nkb,&npmks,&npw,&ModuleBase::ONE, - GlobalC::ppcell.vkb.c,&npwx, - psi_in->get_pointer(),&npwx, - &ModuleBase::ZERO,becp.c,&nkb); - //stochastic orbitals - int npmsto = PARAM.globalv.npol * nstobands; - zgemm_(&transa,&transb,&nkb,&npmsto,&npw,&ModuleBase::ONE, - GlobalC::ppcell.vkb.c,&npwx, - stowf.shchi->get_pointer(),&npwx, - &ModuleBase::ZERO,&becp(nksbands,0),&nkb); - + psi_in->fix_k(ik); + stowf.shchi->fix_k(ik); + // KS orbitals + int npmks = PARAM.globalv.npol * nksbands; + zgemm_(&transa, + &transb, + &nkb, + &npmks, + &npw, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + psi_in->get_pointer(), + &npwx, + &ModuleBase::ZERO, + becp.c, + &nkb); + // stochastic orbitals + int npmsto = PARAM.globalv.npol * nstobands; + zgemm_(&transa, + &transb, + &nkb, + &npmsto, + &npw, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + stowf.shchi->get_pointer(), + &npwx, + &ModuleBase::ZERO, + &becp(nksbands, 0), + &nkb); + Parallel_Reduce::reduce_pool(becp.c, becp.size); - //out.printcm_real("becp",becp,1.0e-4); - // Calculate the derivative of beta, - // |dbeta> = -ig * |beta> + // out.printcm_real("becp",becp,1.0e-4); + // Calculate the derivative of beta, + // |dbeta> = -ig * |beta> dbecp.zero_out(); - for (int ipol = 0; ipol<3; ipol++) + for (int ipol = 0; ipol < 3; ipol++) { - for (int i = 0;i < nkb;i++) - { - std::complex* pvkb1 = &vkb1(i,0); - std::complex* pvkb = &GlobalC::ppcell.vkb(i,0); - if (ipol==0) - { - for (int ig=0; ig* pvkb1 = &vkb1(i, 0); + std::complex* pvkb = &GlobalC::ppcell.vkb(i, 0); + if (ipol == 0) + { + for (int ig = 0; ig < npw; ig++) + { pvkb1[ig] = pvkb[ig] * ModuleBase::NEG_IMAG_UNIT * wfc_basis->getgcar(ik, ig)[0]; -} + } } - if (ipol==1) - { - for (int ig=0; iggetgcar(ik,ig)[1]; -} + if (ipol == 1) + { + for (int ig = 0; ig < npw; ig++) + { + pvkb1[ig] = pvkb[ig] * ModuleBase::NEG_IMAG_UNIT * wfc_basis->getgcar(ik, ig)[1]; + } } - if (ipol==2) - { - for (int ig=0; iggetgcar(ik,ig)[2]; -} + if (ipol == 2) + { + for (int ig = 0; ig < npw; ig++) + { + pvkb1[ig] = pvkb[ig] * ModuleBase::NEG_IMAG_UNIT * wfc_basis->getgcar(ik, ig)[2]; + } } - } - //KS orbitals - zgemm_(&transa,&transb,&nkb,&npmks,&npw,&ModuleBase::ONE, - vkb1.c,&npwx, - psi_in->get_pointer(),&npwx, - &ModuleBase::ZERO,&dbecp(ipol, 0, 0),&nkb); - //stochastic orbitals - zgemm_(&transa,&transb,&nkb,&npmsto,&npw,&ModuleBase::ONE, - vkb1.c,&npwx, - stowf.shchi->get_pointer(),&npwx, - &ModuleBase::ZERO,&dbecp(ipol, nksbands, 0),&nkb); - }// end ipol - -// don't need to reduce here, keep dbecp different in each processor, -// and at last sum up all the forces. -// Parallel_Reduce::reduce_complex_double_pool( dbecp.ptr, dbecp.ndata); - -// double *cf = new double[ucell.nat*3]; -// ZEROS(cf, ucell.nat); - for (int ib=0; ibget_pointer(), + &npwx, + &ModuleBase::ZERO, + &dbecp(ipol, 0, 0), + &nkb); + // stochastic orbitals + zgemm_(&transa, + &transb, + &nkb, + &npmsto, + &npw, + &ModuleBase::ONE, + vkb1.c, + &npwx, + stowf.shchi->get_pointer(), + &npwx, + &ModuleBase::ZERO, + &dbecp(ipol, nksbands, 0), + &nkb); + } // end ipol + + // don't need to reduce here, keep dbecp different in each processor, + // and at last sum up all the forces. + // Parallel_Reduce::reduce_complex_double_pool( dbecp.ptr, dbecp.ndata); + + // double *cf = new double[ucell.nat*3]; + // ZEROS(cf, ucell.nat); + for (int ib = 0; ib < nbandstot; ib++) + { + double fac; + if (ib < nksbands) + { + fac = wg(ik, ib) * 2.0 * GlobalC::ucell.tpiba; + } + else + { fac = p_kv->wk[ik] * 2.0 * GlobalC::ucell.tpiba; -} + } int iat = 0; - int sum = 0; - for (int it=0; it< GlobalC::ucell.ntype; it++) - { - const int Nprojs = GlobalC::ucell.atoms[it].ncpp.nh; - for (int ia=0; ia< GlobalC::ucell.atoms[it].na; ia++) - { - for (int ip=0; ip * (4) cal_nl: contribution due to the non-local pseudopotential. * (4) cal_scc: contributino due to incomplete SCF calculation. */ - Sto_Forces(const int nat_in):Forces(nat_in){}; + Sto_Forces(const int nat_in) : Forces(nat_in){}; ~Sto_Forces(){}; void cal_stoforce(ModuleBase::matrix& force, @@ -28,7 +28,7 @@ class Sto_Forces : public Forces K_Vectors* pkv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf); + Stochastic_WF, base_device::DEVICE_CPU>& stowf); private: void cal_sto_force_nl(ModuleBase::matrix& forcenl, @@ -36,7 +36,7 @@ class Sto_Forces : public Forces K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf); + Stochastic_WF, base_device::DEVICE_CPU>& stowf); }; #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index b58d462204..d4d6aee52d 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -1,29 +1,32 @@ #include "sto_iter.h" -#include "module_parameter/parameter.h" #include "module_base/parallel_reduce.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_parameter/parameter.h" -Stochastic_Iter::Stochastic_Iter() +template +Stochastic_Iter::Stochastic_Iter() { change = false; mu0 = 0; method = 2; } -Stochastic_Iter::~Stochastic_Iter() +template +Stochastic_Iter::~Stochastic_Iter() { } -void Stochastic_Iter::init(K_Vectors* pkv_in, - ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, - StoChe& stoche, - hamilt::HamiltSdftPW>* p_hamilt_sto) +template +void Stochastic_Iter::init(K_Vectors* pkv_in, + ModulePW::PW_Basis_K* wfc_basis, + Stochastic_WF& stowf, + StoChe& stoche, + hamilt::HamiltSdftPW* p_hamilt_sto) { p_che = stoche.p_che; spolyv = stoche.spolyv; @@ -35,7 +38,8 @@ void Stochastic_Iter::init(K_Vectors* pkv_in, this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto); } -void Stochastic_Iter::orthog(const int& ik, psi::Psi>& psi, Stochastic_WF& stowf) +template +void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, Stochastic_WF& stowf) { ModuleBase::TITLE("Stochastic_Iter", "orthog"); // orthogonal part @@ -46,14 +50,14 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi>& psi, const int npwx = psi.get_nbasis(); stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); - std::complex*wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); + T *wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); for (int ig = 0; ig < npwx * nchipk; ++ig) { wfgout[ig] = wfgin[ig]; } // orthogonal part - std::complex* sum = new std::complex[PARAM.inp.nbands * nchipk]; + T* sum = new T[PARAM.inp.nbands * nchipk]; char transC = 'C'; char transN = 'N'; @@ -91,7 +95,11 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi>& psi, } } -void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf) +template +void Stochastic_Iter::checkemm(const int& ik, + const int istep, + const int iter, + Stochastic_WF& stowf) { ModuleBase::TITLE("Stochastic_Iter", "checkemm"); // iter = 1,2,... istep = 0,1,2,... @@ -114,7 +122,7 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S } const int norder = p_che->norder; - std::complex* pchi; + T* pchi; int ntest = 1; if (nchip[ik] < ntest) @@ -135,12 +143,13 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S while (true) { bool converge; - auto hchi_norm = std::bind(&hamilt::HamiltSdftPW>::hPsi_norm, + auto hchi_norm = std::bind(&hamilt::HamiltSdftPW::hPsi_norm, p_hamilt_sto, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - converge = p_che->checkconverge(hchi_norm, pchi, npw, stowf.npwx, *p_hamilt_sto->emax, *p_hamilt_sto->emin, 5.0); + converge + = p_che->checkconverge(hchi_norm, pchi, npw, stowf.npwx, *p_hamilt_sto->emax, *p_hamilt_sto->emin, 5.0); if (!converge) { @@ -168,7 +177,8 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S } } -void Stochastic_Iter::check_precision(const double ref, const double thr, const std::string info) +template +void Stochastic_Iter::check_precision(const double ref, const double thr, const std::string info) { //============================== // precision check @@ -210,7 +220,8 @@ void Stochastic_Iter::check_precision(const double ref, const double thr, const //=============================== } -void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) +template +void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) { ModuleBase::TITLE("Stochastic_Iter", "itermu"); ModuleBase::timer::tick("Stochastic_Iter", "itermu"); @@ -304,7 +315,8 @@ void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) return; } -void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) +template +void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) { ModuleBase::TITLE("Stochastic_Iter", "calPn"); ModuleBase::timer::tick("Stochastic_Iter", "calPn"); @@ -324,7 +336,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) ModuleBase::GlobalFunc::ZEROS(spolyv, norder * norder); } } - std::complex* pchi; + T* pchi; if (PARAM.inp.nbands > 0) { stowf.chiortho->fix_k(ik); @@ -336,7 +348,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) pchi = stowf.chi0->get_pointer(); } - auto hchi_norm = std::bind(&hamilt::HamiltSdftPW>::hPsi_norm, + auto hchi_norm = std::bind(&hamilt::HamiltSdftPW::hPsi_norm, p_hamilt_sto, std::placeholders::_1, std::placeholders::_2, @@ -351,8 +363,8 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) } else { - p_che->calpolyvec_complex(hchi_norm, pchi, stowf.chiallorder[ik].c, npw, npwx, nchip_ik); - double* vec_all = (double*)stowf.chiallorder[ik].c; + p_che->calpolyvec_complex(hchi_norm, pchi, stowf.chiallorder[ik].get_pointer(), npw, npwx, nchip_ik); + double* vec_all = (double*)stowf.chiallorder[ik].get_pointer(); char trans = 'T'; char normal = 'N'; double one = 1; @@ -366,7 +378,8 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) return; } -double Stochastic_Iter::calne(elecstate::ElecState* pes) +template +double Stochastic_Iter::calne(elecstate::ElecState* pes) { ModuleBase::timer::tick("Stochastic_Iter", "calne"); double totne = 0; @@ -408,21 +421,22 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) return totne; } -void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) +template +void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) { auto nroot_fd = std::bind(&Sto_Func::nroot_fd, &this->stofunc, std::placeholders::_1); p_che->calcoef_real(nroot_fd); for (int ik = 0; ik < this->pkv->get_nks(); ++ik) { - p_hamilt_sto->updateHk(ik); this->calTnchi_ik(ik, stowf); } } -void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, - elecstate::ElecState* pes, - hamilt::Hamilt>* pHamilt, - ModulePW::PW_Basis_K* wfc_basis) +template +void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, + elecstate::ElecState* pes, + hamilt::Hamilt* pHamilt, + ModulePW::PW_Basis_K* wfc_basis) { ModuleBase::TITLE("Stochastic_Iter", "sum_stoband"); ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband"); @@ -481,14 +495,14 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, const int nchip_ik = nchip[ik]; if (this->pkv->get_nks() > 1) { - pHamilt->updateHk(ik); + pHamilt->updateHk(ik); // can be merged with calTnchi_ik, but it does not nearly cost time. stowf.shchi->fix_k(ik); } const int npw = this->pkv->ngk[ik]; const double kweight = this->pkv->wk[ik]; - std::complex* hshchi = new std::complex[nchip_ik * npwx]; - std::complex* tmpin = stowf.shchi->get_pointer(); - std::complex* tmpout = hshchi; + T* hshchi = new T[nchip_ik * npwx]; + T* tmpin = stowf.shchi->get_pointer(); + T* tmpout = hshchi; p_hamilt_sto->hPsi(tmpin, tmpout, nchip_ik); for (int ichi = 0; ichi < nchip_ik; ++ichi) { @@ -508,11 +522,11 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz; double tmprho, tmpne; - std::complex outtem; + T outtem; double sto_ne = 0; ModuleBase::GlobalFunc::ZEROS(sto_rho, nrxx); - std::complex* porter = new std::complex[nrxx]; + T* porter = new T[nrxx]; double out2; double* ksrho = nullptr; @@ -527,7 +541,7 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, { const int nchip_ik = nchip[ik]; stowf.shchi->fix_k(ik); - std::complex* tmpout = stowf.shchi->get_pointer(); + T* tmpout = stowf.shchi->get_pointer(); for (int ichi = 0; ichi < nchip_ik; ++ichi) { wfc_basis->recip2real(tmpout, porter, ik); @@ -596,13 +610,14 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, return; } -void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF& stowf) +template +void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF& stowf) { const int npw = stowf.ngk[ik]; const int npwx = stowf.npwx; stowf.shchi->fix_k(ik); - std::complex* out = stowf.shchi->get_pointer(); - std::complex* pchi; + T* out = stowf.shchi->get_pointer(); + T* pchi; if (PARAM.inp.nbands > 0) { stowf.chiortho->fix_k(ik); @@ -616,23 +631,27 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF& stowf) if (this->method == 2) { char transa = 'N'; - std::complex one = 1; + T one = 1; int inc = 1; - std::complex zero = 0; + T zero = 0; int LDA = npwx * nchip[ik]; int M = npwx * nchip[ik]; int N = p_che->norder; - std::complex* coef_real = new std::complex[p_che->norder]; + T* coef_real = new T[p_che->norder]; for (int i = 0; i < p_che->norder; ++i) { coef_real[i] = p_che->coef_real[i]; } - zgemv_(&transa, &M, &N, &one, stowf.chiallorder[ik].c, &LDA, coef_real, &inc, &zero, out, &inc); + zgemv_(&transa, &M, &N, &one, stowf.chiallorder[ik].get_pointer(), &LDA, coef_real, &inc, &zero, out, &inc); delete[] coef_real; } else { - auto hchi_norm = std::bind(&hamilt::HamiltSdftPW>::hPsi_norm, + if (this->pkv->get_nks() > 1) + { + p_hamilt_sto->updateHk(ik); // necessary, because itermu should be called before this function + } + auto hchi_norm = std::bind(&hamilt::HamiltSdftPW::hPsi_norm, p_hamilt_sto, std::placeholders::_1, std::placeholders::_2, @@ -640,3 +659,5 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF& stowf) p_che->calfinalvec_real(hchi_norm, pchi, out, npw, npwx, nchip[ik]); } } + +template class Stochastic_Iter, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h index 69cd8cb304..2947f98e7f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h @@ -17,6 +17,7 @@ // rho: charge density //---------------------------------------------- +template class Stochastic_Iter { @@ -38,29 +39,29 @@ class Stochastic_Iter */ void init(K_Vectors* pkv_in, ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, + Stochastic_WF& stowf, StoChe& stoche, - hamilt::HamiltSdftPW>* p_hamilt_sto); + hamilt::HamiltSdftPW* p_hamilt_sto); - void sum_stoband(Stochastic_WF& stowf, + void sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pes, - hamilt::Hamilt>* pHamilt, + hamilt::Hamilt* pHamilt, ModulePW::PW_Basis_K* wfc_basis); double calne(elecstate::ElecState* pes); void itermu(const int iter, elecstate::ElecState* pes); - void orthog(const int& ik, psi::Psi>& psi, Stochastic_WF& stowf); + void orthog(const int& ik, psi::Psi& psi, Stochastic_WF& stowf); - void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf); + void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf); void check_precision(const double ref, const double thr, const std::string info); ModuleBase::Chebyshev* p_che = nullptr; Sto_Func stofunc; - hamilt::HamiltSdftPW>* p_hamilt_sto = nullptr; + hamilt::HamiltSdftPW* p_hamilt_sto = nullptr; double mu0; // chemical potential; unit in Ry bool change; @@ -76,11 +77,11 @@ class Stochastic_Iter public: int method; // different methods 1: slow, less memory 2: fast, more memory // cal shchi = \sqrt{f(\hat{H})}|\chi> - void calHsqrtchi(Stochastic_WF& stowf); + void calHsqrtchi(Stochastic_WF& stowf); // cal Pn = \sum_\chi <\chi|Tn(\hat{h})|\chi> - void calPn(const int& ik, Stochastic_WF& stowf); + void calPn(const int& ik, Stochastic_WF& stowf); // cal Tnchi = \sum_n C_n*T_n(\hat{h})|\chi> - void calTnchi_ik(const int& ik, Stochastic_WF& stowf); + void calTnchi_ik(const int& ik, Stochastic_WF& stowf); private: K_Vectors* pkv; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp index a0a00d8b9b..73205edecd 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp @@ -1,10 +1,10 @@ #include "sto_stress_pw.h" -#include "module_parameter/parameter.h" #include "module_base/timer.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" #include "module_io/output_log.h" +#include "module_parameter/parameter.h" void Sto_Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, const elecstate::ElecState& elec, @@ -14,7 +14,7 @@ void Sto_Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf, + Stochastic_WF, base_device::DEVICE_CPU>& stowf, const Charge* const chr, pseudopot_cell_vnl* nlpp_in, const UnitCell& ucell_in) @@ -99,7 +99,7 @@ void Sto_Stress_PW::sto_stress_kin(ModuleBase::matrix& sigma, K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf) + Stochastic_WF, base_device::DEVICE_CPU>& stowf) { ModuleBase::TITLE("Sto_Stress_PW", "cal_stress"); ModuleBase::timer::tick("Sto_Stress_PW", "cal_stress"); @@ -116,9 +116,10 @@ void Sto_Stress_PW::sto_stress_kin(ModuleBase::matrix& sigma, double twobysqrtpi = 2.0 / std::sqrt(ModuleBase::PI); double* kfac = new double[npwx]; int nksbands = psi_in->get_nbands(); - if (GlobalV::MY_STOGROUP != 0) { + if (GlobalV::MY_STOGROUP != 0) + { nksbands = 0; -} + } for (int ik = 0; ik < nks; ++ik) { @@ -220,7 +221,7 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf, + Stochastic_WF, base_device::DEVICE_CPU>& stowf, pseudopot_cell_vnl* nlpp_in) { ModuleBase::TITLE("Sto_Stress_Func", "stres_nl"); @@ -238,9 +239,10 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in->get_nbands(); - if (GlobalV::MY_STOGROUP != 0) { + if (GlobalV::MY_STOGROUP != 0) + { nksbands = 0; -} + } // vkb1: |Beta(nkb,npw)> ModuleBase::ComplexMatrix vkb1(nkb, npwx); @@ -366,11 +368,14 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, { qvec = wfc_basis->getgpluskcar(ik, ig); double qm1; - if (qvec.norm2() > 1e-16) { + if (qvec.norm2() > 1e-16) + { qm1 = 1.0 / qvec.norm(); - } else { + } + else + { qm1 = 0; -} + } pdbecp_noevc[ig] -= 2.0 * pvkb[ig] * qvec0[ipol][0] * qvec0[jpol][0] * qm1 * this->ucell->tpiba; } // end ig } // end i @@ -407,11 +412,14 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, for (int ib = 0; ib < nbandstot; ++ib) { double fac; - if (ib < nksbands) { + if (ib < nksbands) + { fac = wg(ik, ib); - } else { + } + else + { fac = p_kv->wk[ik]; -} + } int iat = 0; int sum = 0; for (int it = 0; it < this->ucell->ntype; ++it) diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.h b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.h index 72d306756e..e3fc3eadd6 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.h @@ -22,7 +22,7 @@ class Sto_Stress_PW : public Stress_Func K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf, + Stochastic_WF, base_device::DEVICE_CPU>& stowf, const Charge* const chr, pseudopot_cell_vnl* nlpp_in, const UnitCell& ucell_in); @@ -34,7 +34,7 @@ class Sto_Stress_PW : public Stress_Func K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf); + Stochastic_WF, base_device::DEVICE_CPU>& stowf); void sto_stress_nl(ModuleBase::matrix& sigma, const ModuleBase::matrix& wg, @@ -43,7 +43,7 @@ class Sto_Stress_PW : public Stress_Func K_Vectors* p_kv, ModulePW::PW_Basis_K* wfc_basis, const psi::Psi>* psi_in, - Stochastic_WF& stowf, + Stochastic_WF, base_device::DEVICE_CPU>& stowf, pseudopot_cell_vnl* nlpp_in); }; #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp index 0568995b7a..8b350c7777 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp @@ -1,8 +1,8 @@ #include "sto_tool.h" -#include "module_parameter/parameter.h" #include "module_base/math_chebyshev.h" #include "module_base/timer.h" +#include "module_parameter/parameter.h" #ifdef __MPI #include "mpi.h" #endif @@ -13,7 +13,7 @@ void check_che(const int& nche_in, const double& try_emax, const int& nbands_sto, K_Vectors* p_kv, - Stochastic_WF* p_stowf, + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf, hamilt::HamiltSdftPW>* p_hamilt_sto) { //------------------------------ @@ -74,7 +74,13 @@ void check_che(const int& nche_in, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - converge = chetest.checkconverge(hchi_norm, pchi, npw, p_stowf->npwx, *p_hamilt_sto->emax, *p_hamilt_sto->emin, 2.0); + converge = chetest.checkconverge(hchi_norm, + pchi, + npw, + p_stowf->npwx, + *p_hamilt_sto->emax, + *p_hamilt_sto->emin, + 2.0); if (!converge) { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.h b/source/module_hamilt_pw/hamilt_stodft/sto_tool.h index 006be5fd2a..83b941141e 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.h @@ -15,7 +15,7 @@ void check_che(const int& nche_in, const double& try_emax, const int& nbands_sto, K_Vectors* p_kv, - Stochastic_WF* p_stowf, + Stochastic_WF, base_device::DEVICE_CPU>* p_stowf, hamilt::HamiltSdftPW>* p_hamilt_sto); #ifndef PARALLEL_DISTRIBUTION diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 39f4697b56..1f596b06db 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -1,31 +1,38 @@ #include "sto_wf.h" +#include "module_base/memory.h" #include "module_parameter/parameter.h" -#include -#include "module_base/memory.h" +#include #include //---------Temporary------------------------------------ -#include "module_base/complexmatrix.h" #include "module_base/global_function.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" //------------------------------------------------------ -Stochastic_WF::Stochastic_WF() +template +Stochastic_WF::Stochastic_WF() { } -Stochastic_WF::~Stochastic_WF() +template +Stochastic_WF::~Stochastic_WF() { - delete chi0; + delete chi0_cpu; + Device* ctx = {}; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + delete chi0; + } delete shchi; delete chiortho; delete[] nchip; delete[] chiallorder; } -void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) +template +void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) { this->nks = p_kv->get_nks(); this->ngk = p_kv->ngk.data(); @@ -38,22 +45,26 @@ void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) } } -void Stochastic_WF::allocate_chiallorder(const int& norder) +template +void Stochastic_WF::allocate_chiallorder(const int& norder) { - this->chiallorder = new ModuleBase::ComplexMatrix[this->nks]; + this->chiallorder = new psi::Psi[this->nks]; for (int ik = 0; ik < this->nks; ++ik) { - chiallorder[ik].create(this->nchip[ik] * this->npwx, norder,true); + chiallorder[ik].resize(1, this->nchip[ik] * this->npwx, norder); + setmem_complex_op()(chiallorder[ik].get_device(), chiallorder[ik].get_pointer(), 0, chiallorder[ik].size()); } } -void Stochastic_WF::clean_chiallorder() +template +void Stochastic_WF::clean_chiallorder() { delete[] chiallorder; chiallorder = nullptr; } -void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in) +template +void Stochastic_WF::init_sto_orbitals(const int seed_in) { if (seed_in == 0 || seed_in == -1) { @@ -64,11 +75,12 @@ void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in) srand((unsigned)std::abs(seed_in) + GlobalV::MY_RANK * 10000); } - Allocate_Chi0(stowf); - Update_Sto_Orbitals(stowf, seed_in); + this->allocate_chi0(); + this->update_sto_orbitals(seed_in); } -void Allocate_Chi0(Stochastic_WF& stowf) +template +void Stochastic_WF::allocate_chi0() { bool firstrankmore = false; int igroup = 0; @@ -84,12 +96,12 @@ void Allocate_Chi0(Stochastic_WF& stowf) igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1; } const int nchi = PARAM.inp.nbands_sto; - const int npwx = stowf.npwx; - const int nks = stowf.nks; + const int npwx = this->npwx; + const int nks = this->nks; const int ngroup = PARAM.inp.bndpar; if (ngroup <= 0) { - ModuleBase::WARNING_QUIT("Init_Sto_Orbitals", "ngroup <= 0!"); + ModuleBase::WARNING_QUIT("init_sto_orbitals", "ngroup <= 0!"); } int tmpnchip = int(nchi / ngroup); if (igroup < nchi % ngroup) @@ -97,51 +109,65 @@ void Allocate_Chi0(Stochastic_WF& stowf) ++tmpnchip; } - stowf.nchip_max = tmpnchip; - size_t size = stowf.nchip_max * npwx * nks; - stowf.chi0 = new psi::Psi>(nks, stowf.nchip_max, npwx, stowf.ngk); - ModuleBase::Memory::record("SDFT::chi0", size * sizeof(std::complex)); + this->nchip_max = tmpnchip; + size_t size = this->nchip_max * npwx * nks; + this->chi0_cpu = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(T)); for (int ik = 0; ik < nks; ++ik) { - stowf.nchip[ik] = tmpnchip; + this->nchip[ik] = tmpnchip; + } + + // allocate chi0 + Device* ctx = {}; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + } + else + { + this->chi0 = reinterpret_cast*>(this->chi0_cpu); } } -void Update_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in) +template +void Stochastic_WF::update_sto_orbitals(const int seed_in) { const int nchi = PARAM.inp.nbands_sto; - stowf.chi0->fix_k(0); + this->chi0_cpu->fix_k(0); if (seed_in >= 0) { - for (int i = 0; i < stowf.chi0->size(); ++i) + for (int i = 0; i < this->chi0_cpu->size(); ++i) { const double phi = 2 * ModuleBase::PI * rand() / double(RAND_MAX); - stowf.chi0->get_pointer()[i] = std::complex(cos(phi), sin(phi)) / sqrt(double(nchi)); + this->chi0_cpu->get_pointer()[i] = std::complex(cos(phi), sin(phi)) / sqrt(double(nchi)); } } else { - for (int i = 0; i < stowf.chi0->size(); ++i) + for (int i = 0; i < this->chi0_cpu->size(); ++i) { if (rand() / double(RAND_MAX) < 0.5) { - stowf.chi0->get_pointer()[i] = -1.0 / sqrt(double(nchi)); + this->chi0_cpu->get_pointer()[i] = -1.0 / sqrt(double(nchi)); } else { - stowf.chi0->get_pointer()[i] = 1.0 / sqrt(double(nchi)); + this->chi0_cpu->get_pointer()[i] = 1.0 / sqrt(double(nchi)); } } } + this->sync_chi0(); } #ifdef __MPI -void Init_Com_Orbitals(Stochastic_WF& stowf) +template +void Stochastic_WF::init_com_orbitals() { const bool firstrankmore = false; - const int npwx = stowf.npwx; - const int nks = stowf.nks; + const int npwx = this->npwx; + const int nks = this->nks; int igroup; // former processor calculate more bands if (firstrankmore) @@ -162,7 +188,7 @@ void Init_Com_Orbitals(Stochastic_WF& stowf) for (int ik = 0; ik < nks; ++ik) { int* npwip = new int[n_in_pool]; - const int npw = stowf.ngk[ik]; + const int npw = this->ngk[ik]; totnpw[ik] = 0; MPI_Allgather(&npw, 1, MPI_INT, npwip, 1, MPI_INT, POOL_WORLD); @@ -176,22 +202,22 @@ void Init_Com_Orbitals(Stochastic_WF& stowf) { ++tmpnchip; } - stowf.nchip[ik] = tmpnchip; - stowf.nchip_max = std::max(tmpnchip, stowf.nchip_max); + this->nchip[ik] = tmpnchip; + this->nchip_max = std::max(tmpnchip, this->nchip_max); delete[] npwip; } - size_t size = stowf.nchip_max * npwx * nks; - stowf.chi0 = new psi::Psi>(nks, stowf.nchip_max, npwx, stowf.ngk); - stowf.chi0->zero_out(); - ModuleBase::Memory::record("SDFT::chi0", size * sizeof(std::complex)); + size_t size = this->nchip_max * npwx * nks; + this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk); + this->chi0_cpu->zero_out(); + ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) { int* npwip = new int[n_in_pool]; - const int npw = stowf.ngk[ik]; + const int npw = this->ngk[ik]; MPI_Allgather(&npw, 1, MPI_INT, npwip, 1, MPI_INT, POOL_WORLD); const int re = totnpw[ik] % ngroup; int ip = 0, ig0 = 0; - const int nchipk = stowf.nchip[ik]; + const int nchipk = this->nchip[ik]; // give value to orbitals in one parallel group one by one. for (int ichi = 0; ichi < nchipk; ++ichi) { @@ -215,42 +241,64 @@ void Init_Com_Orbitals(Stochastic_WF& stowf) } if (i_in_pool == ip) { - stowf.chi0->operator()(ik, ichi, ig) = 1; + this->chi0_cpu->operator()(ik, ichi, ig) = 1; } } delete[] npwip; } delete[] totnpw; + // allocate chi0 + Device* ctx = {}; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + } + else + { + this->chi0 = reinterpret_cast*>(this->chi0_cpu); + } } #else -void Init_Com_Orbitals(Stochastic_WF& stowf) +template +void Stochastic_WF::init_com_orbitals() { - const int npwx = stowf.npwx; - const int nks = stowf.nks; - size_t size = stowf.nchip_max * npwx * nks; - stowf.chi0 = new psi::Psi>(nks, npwx, npwx, stowf.ngk); - stowf.chi0->zero_out(); - ModuleBase::Memory::record("SDFT::chi0", size * sizeof(std::complex)); + const int npwx = this->npwx; + const int nks = this->nks; + size_t size = this->nchip_max * npwx * nks; + this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk); + this->chi0_cpu->zero_out(); + ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) { - const int npw = stowf.ngk[ik]; - stowf.nchip[ik] = npwx; - stowf.nchip_max = npwx; + const int npw = this->ngk[ik]; + this->nchip[ik] = npwx; + this->nchip_max = npwx; for (int ichi = 0; ichi < npw; ++ichi) { - stowf.chi0->operator()(ik, ichi, ichi) = 1; + this->chi0_cpu->operator()(ik, ichi, ichi) = 1; } } + + // allocate chi0 + Device* ctx = {}; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + } + else + { + this->chi0 = reinterpret_cast*>(this->chi0_cpu); + } } #endif -void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, - const int seed_in, - const K_Vectors& kv, - const ModulePW::PW_Basis_K& wfcpw, - const int max_ecut) +template +void Stochastic_WF::init_sto_orbitals_Ecut(const int seed_in, + const K_Vectors& kv, + const ModulePW::PW_Basis_K& wfcpw, + const int max_ecut) { - Allocate_Chi0(stowf); + this->allocate_chi0(); ModulePW::PW_Basis pwmax; #ifdef __MPI @@ -265,7 +313,7 @@ void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, const int nchitot = PARAM.inp.nbands_sto; bool* updown = new bool[nx * ny * nz]; int* nrecv = new int[PARAM.inp.bndpar]; - const int nchiper = stowf.nchip[0]; + const int nchiper = this->nchip[0]; #ifdef __MPI MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD); #endif @@ -306,11 +354,11 @@ void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, { if (updown[ig2ixyz[ig]]) { - stowf.chi0->operator()(ik, ichi, ig) = -1.0 / sqrt(double(nchitot)); + this->chi0_cpu->operator()(ik, ichi, ig) = -1.0 / sqrt(double(nchitot)); } else { - stowf.chi0->operator()(ik, ichi, ig) = 1.0 / sqrt(double(nchitot)); + this->chi0_cpu->operator()(ik, ichi, ig) = 1.0 / sqrt(double(nchitot)); } } } @@ -319,3 +367,19 @@ void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, delete[] nrecv; delete[] updown; } + +template +void Stochastic_WF::sync_chi0() +{ + Device* ctx = {}; + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + syncmem_h2d_op()(this->chi0->get_device(), + this->chi0_cpu->get_device(), + this->chi0->get_pointer(), + this->chi0_cpu->get_pointer(), + this->chi0_cpu->size()); + } +} + +template class Stochastic_WF, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index f8811f8bb0..e6f954f9af 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -1,7 +1,7 @@ #ifndef STOCHASTIC_WF_H #define STOCHASTIC_WF_H -#include "module_base/complexmatrix.h" +#include "module_base/module_container/ATen/tensor.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_cell/klist.h" #include "module_psi/psi.h" @@ -9,6 +9,7 @@ //---------------------------------------------- // Generate stochastic wave functions //---------------------------------------------- +template class Stochastic_WF { public: @@ -18,12 +19,14 @@ class Stochastic_WF void init(K_Vectors* p_kv, const int npwx_in); - // origin stochastic wavefunctions in real space - psi::Psi>* chi0 = nullptr; + // origin stochastic wavefunctions in CPU + psi::Psi* chi0_cpu = nullptr; + // origin stochastic wavefunctions in GPU or CPU + psi::Psi* chi0 = nullptr; // stochastic wavefunctions after in reciprocal space orthogonalized with KS wavefunctions - psi::Psi>* chiortho = nullptr; + psi::Psi* chiortho = nullptr; // sqrt(f(H))|chi> - psi::Psi>* shchi = nullptr; + psi::Psi* shchi = nullptr; int nchi = 0; ///< Total number of stochatic obitals int* nchip = nullptr; ///< The number of stochatic orbitals in current process of each k point. int nchip_max = 0; ///< Max number of stochastic orbitals among all k points. @@ -34,25 +37,32 @@ class Stochastic_WF int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag; public: // Tn(H)|chi> - ModuleBase::ComplexMatrix* chiallorder = nullptr; + psi::Psi* chiallorder = nullptr; // allocate chiallorder void allocate_chiallorder(const int& norder); // chiallorder cost too much memories and should be cleaned after scf. void clean_chiallorder(); + + public: + // init stochastic orbitals + void init_sto_orbitals(const int seed_in); + // init stochastic orbitals from a large Ecut + // It can test the convergence of SDFT with respect to Ecut + void init_sto_orbitals_Ecut(const int seed_in, + const K_Vectors& kv, + const ModulePW::PW_Basis_K& wfcpw, + const int max_ecut); + // allocate chi0 + void allocate_chi0(); + // update stochastic orbitals + void update_sto_orbitals(const int seed_in); + // init complete orbitals + void init_com_orbitals(); + // sync chi0 from CPU to GPU + void sync_chi0(); + + protected: + using setmem_complex_op = base_device::memory::set_memory_op; + using syncmem_h2d_op = base_device::memory::synchronize_memory_op; }; -// init stochastic orbitals -void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in); -// init stochastic orbitals from a large Ecut -// It can test the convergence of SDFT with respect to Ecut -void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, - const int seed_in, - const K_Vectors& kv, - const ModulePW::PW_Basis_K& wfcpw, - const int max_ecut); -// allocate chi0 -void Allocate_Chi0(Stochastic_WF& stowf); -// update stochastic orbitals -void Update_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in); -// init complete orbitals -void Init_Com_Orbitals(Stochastic_WF& stowf); #endif diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index a8ca9f5b16..a9ccac3380 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -12,7 +12,7 @@ namespace hsolver template class HSolverPW { - private: + protected: // Note GetTypeReal::type will // return T if T is real type(float, double), // otherwise return the real type of T(complex, complex) @@ -84,12 +84,12 @@ class HSolverPW const bool need_subspace; // for cg or dav_subspace const bool initialed_psi; - private: + protected: Device* ctx = {}; int rank_in_pool = 0; int nproc_in_pool = 1; - + private: /// @brief calculate the threshold for iterative-diagonalization for each band void cal_ethr_band(const double& wk, const double* wg, const double& ethr, std::vector& ethrs); diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 59e86fa57d..d71e23d29b 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -9,15 +9,15 @@ namespace hsolver { - -void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, - psi::Psi>& psi, - elecstate::ElecState* pes, - ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, - const int istep, - const int iter, - const bool skip_charge) +template +void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, + psi::Psi& psi, + elecstate::ElecState* pes, + ModulePW::PW_Basis_K* wfc_basis, + Stochastic_WF& stowf, + const int istep, + const int iter, + const bool skip_charge) { ModuleBase::TITLE("HSolverPW_SDFT", "solve"); ModuleBase::timer::tick("HSolverPW_SDFT", "solve"); @@ -44,7 +44,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, { this->updatePsiK(pHamilt, psi, ik); // template add precondition calculating here - update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0()); + this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0()); /// solve eigenvector and eigenvalue for H(k) double* p_eigenvalues = &(pes->ekb(ik, 0)); this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues); @@ -68,12 +68,15 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, // init k if (nks > 1) { - pHamilt->updateHk(ik); + pHamilt->updateHk(ik); // necessary , because emax and emin should be decided first } stoiter.calPn(ik, stowf); } + // iterate to get mu stoiter.itermu(iter, pes); + + // prepare sqrt{f(\hat{H})}|\chi> to calculate density, force and stress stoiter.calHsqrtchi(stowf); if (skip_charge) { @@ -104,4 +107,6 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, return; } +// template class HSolverPW_SDFT, base_device::DEVICE_CPU>; +template class HSolverPW_SDFT, base_device::DEVICE_CPU>; } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index 48ef2b38e5..6fc0a39fec 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -1,19 +1,20 @@ #ifndef HSOLVERPW_SDFT_H #define HSOLVERPW_SDFT_H #include "hsolver_pw.h" -#include "module_hamilt_pw/hamilt_stodft/sto_iter.h" #include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h" +#include "module_hamilt_pw/hamilt_stodft/sto_iter.h" namespace hsolver { -class HSolverPW_SDFT : public HSolverPW> +template +class HSolverPW_SDFT : public HSolverPW { public: HSolverPW_SDFT(K_Vectors* pkv, ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, - Stochastic_WF& stowf, + Stochastic_WF& stowf, StoChe& stoche, - hamilt::HamiltSdftPW>* p_hamilt_sto, + hamilt::HamiltSdftPW* p_hamilt_sto, const std::string calculation_type_in, const std::string basis_type_in, const std::string method_in, @@ -25,33 +26,33 @@ class HSolverPW_SDFT : public HSolverPW> const double diag_thr_in, const bool need_subspace_in, const bool initialed_psi_in) - : HSolverPW(wfc_basis_in, - pwf_in, - calculation_type_in, - basis_type_in, - method_in, - use_paw_in, - use_uspp_in, - nspin_in, - scf_iter_in, - diag_iter_max_in, - diag_thr_in, - need_subspace_in, - initialed_psi_in) + : HSolverPW(wfc_basis_in, + pwf_in, + calculation_type_in, + basis_type_in, + method_in, + use_paw_in, + use_uspp_in, + nspin_in, + scf_iter_in, + diag_iter_max_in, + diag_thr_in, + need_subspace_in, + initialed_psi_in) { stoiter.init(pkv, wfc_basis_in, stowf, stoche, p_hamilt_sto); } - void solve(hamilt::Hamilt>* pHamilt, - psi::Psi>& psi, + void solve(hamilt::Hamilt* pHamilt, + psi::Psi& psi, elecstate::ElecState* pes, ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, + Stochastic_WF& stowf, const int istep, const int iter, const bool skip_charge); - Stochastic_Iter stoiter; + Stochastic_Iter stoiter; }; } // namespace hsolver #endif \ No newline at end of file diff --git a/source/module_hsolver/test/hsolver_supplementary_mock.h b/source/module_hsolver/test/hsolver_supplementary_mock.h index fd7d482d09..59e07df85f 100644 --- a/source/module_hsolver/test/hsolver_supplementary_mock.h +++ b/source/module_hsolver/test/hsolver_supplementary_mock.h @@ -1,5 +1,6 @@ #pragma once #include "module_elecstate/elecstate.h" +#include "module_hamilt_pw/hamilt_pwdft/wavefunc.h" namespace elecstate { @@ -11,9 +12,7 @@ const double* ElecState::getRho(int spin) const return &(this->charge->rho[spin][0]); } -void ElecState::fixed_weights(const std::vector& ocp_kb, - const int &nbands, - const double &nelec) +void ElecState::fixed_weights(const std::vector& ocp_kb, const int& nbands, const double& nelec) { return; } @@ -43,7 +42,10 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) return; } -void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, ModuleSymmetry::Symmetry&, const void*) +void ElecState::init_scf(const int istep, + const ModuleBase::ComplexMatrix& strucfac, + ModuleSymmetry::Symmetry&, + const void*) { return; } @@ -57,18 +59,24 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge return; } -Potential::~Potential(){} +Potential::~Potential() +{ +} -void Potential::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff){} +void Potential::cal_v_eff(const Charge* const chg, const UnitCell* const ucell, ModuleBase::matrix& v_eff) +{ +} -void Potential::cal_fixed_v(double* vl_pseudo){} +void Potential::cal_fixed_v(double* vl_pseudo) +{ +} } // namespace elecstate - -//mock of Stochastic_WF +// mock of Stochastic_WF #include "module_hamilt_pw/hamilt_stodft/sto_wf.h" -Stochastic_WF::Stochastic_WF() +template +Stochastic_WF::Stochastic_WF() { chiortho = nullptr; chi0 = nullptr; @@ -76,7 +84,8 @@ Stochastic_WF::Stochastic_WF() nchip = nullptr; } -Stochastic_WF::~Stochastic_WF() +template +Stochastic_WF::~Stochastic_WF() { delete[] chi0; delete[] shchi; @@ -84,7 +93,8 @@ Stochastic_WF::~Stochastic_WF() delete[] nchip; } -void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) +template +void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) { /*chi0 = new ModuleBase::ComplexMatrix[nks_in]; shchi = new ModuleBase::ComplexMatrix[nks_in]; @@ -94,8 +104,12 @@ void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) } #include "module_cell/klist.h" -K_Vectors::K_Vectors(){} -K_Vectors::~K_Vectors(){} +K_Vectors::K_Vectors() +{ +} +K_Vectors::~K_Vectors() +{ +} wavefunc::wavefunc() { } diff --git a/source/module_hsolver/test/test_hsolver.cpp b/source/module_hsolver/test/test_hsolver.cpp index 4f5adb96ac..8e2d993ecd 100644 --- a/source/module_hsolver/test/test_hsolver.cpp +++ b/source/module_hsolver/test/test_hsolver.cpp @@ -3,8 +3,8 @@ #include #define protected public -#include "module_hsolver/hsolver.h" #include "hsolver_supplementary_mock.h" +#include "module_hsolver/hsolver.h" #include @@ -28,36 +28,36 @@ * - diag() for Psi(FPTYPE) case * - destructor of DiagH and HSolver * - * the definition of supplementary functions is added in hsolver_supplementary_mock.h + * the definition of supplementary functions is added in hsolver_supplementary_mock.h */ class TestHSolver : public ::testing::Test { -public: - // hsolver::HSolver, base_device::DEVICE_CPU> hs_cf; - // hsolver::HSolver, base_device::DEVICE_CPU> hs_cd; - // hsolver::HSolver hs_f; - // hsolver::HSolver hs_d; + public: + // hsolver::HSolver, base_device::DEVICE_CPU> hs_cf; + // hsolver::HSolver, base_device::DEVICE_CPU> hs_cd; + // hsolver::HSolver hs_f; + // hsolver::HSolver hs_d; - hamilt::Hamilt> hamilt_test_cd; - hamilt::Hamilt> hamilt_test_cf; - psi::Psi> psi_test_cd; - psi::Psi> psi_test_cf; + hamilt::Hamilt> hamilt_test_cd; + hamilt::Hamilt> hamilt_test_cf; + psi::Psi> psi_test_cd; + psi::Psi> psi_test_cf; - hamilt::Hamilt hamilt_test_d; - hamilt::Hamilt hamilt_test_f; - psi::Psi psi_test_d; - psi::Psi psi_test_f; + hamilt::Hamilt hamilt_test_d; + hamilt::Hamilt hamilt_test_f; + psi::Psi psi_test_d; + psi::Psi psi_test_f; - Stochastic_WF stowf_test; + Stochastic_WF> stowf_test; - elecstate::ElecState elecstate_test; + elecstate::ElecState elecstate_test; - ModulePW::PW_Basis_K* wfcpw; + ModulePW::PW_Basis_K* wfcpw; - std::string method_test = "none"; + std::string method_test = "none"; - std::ofstream temp_ofs; + std::ofstream temp_ofs; }; // TEST_F(TestHSolver, solve) diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 3006aff076..687bc38bc7 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -7,108 +7,113 @@ #define private public #define protected public -#include "module_hsolver/hsolver_pw.h" -#include "hsolver_supplementary_mock.h" #include "hsolver_pw_sup.h" -#include "module_hsolver/hsolver_pw_sdft.h" +#include "hsolver_supplementary_mock.h" #include "module_base/global_variable.h" +#include "module_hsolver/hsolver_pw.h" +#include "module_hsolver/hsolver_pw_sdft.h" #undef private #undef protected -//mock for module_sdft -template -Sto_Func::Sto_Func(){} +// mock for module_sdft +template +Sto_Func::Sto_Func() +{ +} template class Sto_Func; -template +template StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) { this->nche = nche; } -template -StoChe::~StoChe(){} +template +StoChe::~StoChe() +{ +} template class StoChe; -Stochastic_Iter::Stochastic_Iter() +template +Stochastic_Iter::Stochastic_Iter() { change = false; mu0 = 0; method = 2; } -Stochastic_Iter::~Stochastic_Iter(){}; +template +Stochastic_Iter::~Stochastic_Iter(){}; +template class Stochastic_Iter, base_device::DEVICE_CPU>; -void Stochastic_Iter::init(K_Vectors* pkv_in, - ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, - StoChe& stoche, - hamilt::HamiltSdftPW>* p_hamilt_sto) +template +void Stochastic_Iter::init(K_Vectors* pkv_in, + ModulePW::PW_Basis_K* wfc_basis, + Stochastic_WF& stowf, + StoChe& stoche, + hamilt::HamiltSdftPW* p_hamilt_sto) { - this->nchip = stowf.nchip;; + this->nchip = stowf.nchip; + ; this->targetne = 1; this->method = stoche.method_sto; } -void Stochastic_Iter::orthog(const int& ik, - psi::Psi& psi, - Stochastic_WF& stowf) +template +void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, Stochastic_WF& stowf) { - //do something to verify this function has been called - for(int i=0;i +void Stochastic_Iter::checkemm(const int& ik, + const int istep, + const int iter, + Stochastic_WF& stowf) { - //do something to verify this function has been called + // do something to verify this function has been called stowf.nchi++; return; } -void Stochastic_Iter::calPn( - const int &ik, - Stochastic_WF &stowf -) +template +void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) { - //do something to verify this function has been called - stowf.nbands_diag ++; + // do something to verify this function has been called + stowf.nbands_diag++; return; } -void Stochastic_Iter::itermu( - int iter, - elecstate::ElecState *pes -) +template +void Stochastic_Iter::itermu(int iter, elecstate::ElecState* pes) { - //do something to verify this function has been called + // do something to verify this function has been called pes->f_en.eband += 1.2; return; } -void Stochastic_Iter::calHsqrtchi(Stochastic_WF &stowf) +template +void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) { - //do something to verify this function has been called + // do something to verify this function has been called stowf.nchip_max++; return; } -void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, - elecstate::ElecState* pes, - hamilt::Hamilt, base_device::DEVICE_CPU>* pHamilt, - ModulePW::PW_Basis_K* wfc_basis) +template +void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, + elecstate::ElecState* pes, + hamilt::Hamilt* pHamilt, + ModulePW::PW_Basis_K* wfc_basis) { - //do something to verify this function has been called - stowf.nbands_total ++; + // do something to verify this function has been called + stowf.nbands_total++; return; } @@ -130,76 +135,73 @@ Charge::~Charge(){}; */ class TestHSolverPW_SDFT : public ::testing::Test { - public: - TestHSolverPW_SDFT():stoche(8,1,0,0){} + public: + TestHSolverPW_SDFT() : stoche(8, 1, 0, 0) + { + } ModulePW::PW_Basis_K pwbk; - Stochastic_WF stowf; + Stochastic_WF> stowf; K_Vectors kv; wavefunc wf; StoChe stoche; hamilt::HamiltSdftPW>* p_hamilt_sto = nullptr; - hsolver::HSolverPW_SDFT hs_d = hsolver::HSolverPW_SDFT(&kv, - &pwbk, - &wf, - stowf, - stoche, - p_hamilt_sto, - "scf", - "pw", - "cg", - false, - PARAM.sys.use_uspp, - PARAM.input.nspin, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, - hsolver::DiagoIterAssist>::need_subspace, - false); + hsolver::HSolverPW_SDFT, base_device::DEVICE_CPU> hs_d + = hsolver::HSolverPW_SDFT, base_device::DEVICE_CPU>( + &kv, + &pwbk, + &wf, + stowf, + stoche, + p_hamilt_sto, + "scf", + "pw", + "cg", + false, + PARAM.sys.use_uspp, + PARAM.input.nspin, + hsolver::DiagoIterAssist>::SCF_ITER, + hsolver::DiagoIterAssist>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist>::PW_DIAG_THR, + hsolver::DiagoIterAssist>::need_subspace, + false); hamilt::Hamilt> hamilt_test_d; - psi::Psi> psi_test_cd; + psi::Psi> psi_test_cd; psi::Psi> psi_test_no; - elecstate::ElecState elecstate_test; + elecstate::ElecState elecstate_test; - std::string method_test = "cg"; + std::string method_test = "cg"; - std::ofstream temp_ofs; + std::ofstream temp_ofs; }; TEST_F(TestHSolverPW_SDFT, solve) { - //initial memory and data - elecstate_test.ekb.create(1,2); + // initial memory and data + elecstate_test.ekb.create(1, 2); elecstate_test.pot = new elecstate::Potential; elecstate_test.f_en.eband = 0.0; stowf.nbands_diag = 0; stowf.nbands_total = 0; stowf.nchi = 0; stowf.nchip_max = 0; - psi_test_cd.resize(1, 2, 3); - PARAM.input.nelec = 1.0; + psi_test_cd.resize(1, 2, 3); + PARAM.input.nelec = 1.0; GlobalV::MY_STOGROUP = 0.0; int istep = 0; int iter = 0; - this->hs_d.solve(&hamilt_test_d, - psi_test_cd, - &elecstate_test, - &pwbk, - stowf, - istep, - iter, - false); - EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); - EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0); - EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0); - for(int i=0;ihs_d.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false); + EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); + EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0); + EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0); + for (int i = 0; i < psi_test_cd.size(); i++) + { + // std::cout<<__FILE__<<__LINE__<<" "<hs_d.solve(&hamilt_test_d, - psi_test_no, - &elecstate_test, - &pwbk, - stowf, - istep, - iter, - false - ); - EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); + this->hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false); + EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); EXPECT_EQ(stowf.nbands_diag, 2); EXPECT_EQ(stowf.nbands_total, 1); EXPECT_EQ(stowf.nchi, 2); @@ -256,16 +250,8 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) std::cout<<__FILE__<<__LINE__<<" "<hs_d.solve(&hamilt_test_d, - psi_test_no, - &elecstate_test, - &pwbk, - stowf, - istep, - iter, - true - ); + // test for skip charge + this->hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true); EXPECT_EQ(stowf.nbands_diag, 4); EXPECT_EQ(stowf.nbands_total, 1); EXPECT_EQ(stowf.nchi, 4); @@ -275,27 +261,26 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) delete[] elecstate_test.charge->rho[0]; delete[] elecstate_test.charge->rho; delete elecstate_test.charge; - } #ifdef __MPI -#include "mpi.h" #include "module_base/timer.h" -int main(int argc, char **argv) +#include "mpi.h" +int main(int argc, char** argv) { - ModuleBase::timer::disable(); - MPI_Init(&argc, &argv); - testing::InitGoogleTest(&argc, argv); + ModuleBase::timer::disable(); + MPI_Init(&argc, &argv); + testing::InitGoogleTest(&argc, argv); - MPI_Comm_size(MPI_COMM_WORLD,&GlobalV::NPROC); - MPI_Comm_rank(MPI_COMM_WORLD,&GlobalV::MY_RANK); + MPI_Comm_size(MPI_COMM_WORLD, &GlobalV::NPROC); + MPI_Comm_rank(MPI_COMM_WORLD, &GlobalV::MY_RANK); MPI_Comm_split(MPI_COMM_WORLD, 0, 1, &PARAPW_WORLD); - int result = RUN_ALL_TESTS(); - + int result = RUN_ALL_TESTS(); + MPI_Comm_free(&PARAPW_WORLD); - MPI_Finalize(); - - return result; + MPI_Finalize(); + + return result; } #endif \ No newline at end of file