From 38f89fc5900f4bae5b22de3a8ddec8659e15c65d Mon Sep 17 00:00:00 2001 From: YuLiu98 Date: Thu, 7 Nov 2024 10:09:21 +0800 Subject: [PATCH 1/3] Refactor: refactor iter_finish and iter_init --- source/module_esolver/esolver_ks.cpp | 154 +++++++++++----------- source/module_esolver/esolver_ks.h | 87 ++++++------ source/module_esolver/esolver_ks_lcao.cpp | 3 + source/module_esolver/esolver_ks_pw.cpp | 3 + 4 files changed, 121 insertions(+), 126 deletions(-) diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 5e9da976d0..bd56912a00 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -427,49 +427,11 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) this->niter = this->maxniter; // 4) SCF iterations - double diag_ethr = PARAM.inp.pw_diag_thr; + this->diag_ethr = PARAM.inp.pw_diag_thr; std::cout << " * * * * * *\n << Start SCF iteration." << std::endl; for (int iter = 1; iter <= this->maxniter; ++iter) { - // 5) write head - ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname); - -#ifdef __MPI - auto iterstart = MPI_Wtime(); -#else - auto iterstart = std::chrono::system_clock::now(); -#endif - - if (PARAM.inp.esolver_type == "ksdft") - { - diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type, - PARAM.inp.esolver_type, - PARAM.inp.calculation, - PARAM.inp.init_chg, - PARAM.inp.precision, - istep, - iter, - drho, - PARAM.inp.pw_diag_thr, - diag_ethr, - PARAM.inp.nelec); - } - else if (PARAM.inp.esolver_type == "sdft") - { - diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type, - PARAM.inp.esolver_type, - PARAM.inp.calculation, - PARAM.inp.init_chg, - istep, - iter, - drho, - PARAM.inp.pw_diag_thr, - diag_ethr, - PARAM.inp.nbands, - esolver_KS_ne); - } - // 6) initialization of SCF iterations this->iter_init(istep, iter); @@ -610,33 +572,6 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) // 10) finish scf iterations this->iter_finish(istep, iter); -#ifdef __MPI - double duration = (double)(MPI_Wtime() - iterstart); -#else - double duration - = (std::chrono::duration_cast(std::chrono::system_clock::now() - iterstart)) - .count() - / static_cast(1e6); -#endif - - // 11) get mtaGGA related parameters - double dkin = 0.0; // for meta-GGA - if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) - { - dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec); - } - this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr); - - // 12) Json, need to be moved to somewhere else -#ifdef __RAPIDJSON - // add Json of scf mag - Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization, - GlobalC::ucell.magnet.abs_magnetization, - this->pelec->f_en.etot * ModuleBase::Ry_to_eV, - this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, - drho, - duration); -#endif //__RAPIDJSON // 13) check convergence if (this->conv_esolver) @@ -644,12 +579,6 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) this->niter = iter; break; } - - // notice for restart - if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) - { - std::cout << " SCF restart after this step!" << std::endl; - } } // end scf iterations std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl; @@ -662,6 +591,47 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) return; }; +template +void ESolver_KS::iter_init(const int istep, const int iter) +{ + ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname); + +#ifdef __MPI + iter_time = MPI_Wtime(); +#else + iter_time = std::chrono::system_clock::now(); +#endif + + if (PARAM.inp.esolver_type == "ksdft") + { + diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + PARAM.inp.precision, + istep, + iter, + drho, + PARAM.inp.pw_diag_thr, + diag_ethr, + PARAM.inp.nelec); + } + else if (PARAM.inp.esolver_type == "sdft") + { + diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + istep, + iter, + drho, + PARAM.inp.pw_diag_thr, + diag_ethr, + PARAM.inp.nbands, + esolver_KS_ne); + } +} + template void ESolver_KS::iter_finish(const int istep, int& iter) { @@ -675,6 +645,39 @@ void ESolver_KS::iter_finish(const int istep, int& iter) } this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old; this->pelec->f_en.etot_old = this->pelec->f_en.etot; + +#ifdef __MPI + double duration = (double)(MPI_Wtime() - iter_time); +#else + double duration + = (std::chrono::duration_cast(std::chrono::system_clock::now() - iter_time)).count() + / static_cast(1e6); +#endif + + // get mtaGGA related parameters + double dkin = 0.0; // for meta-GGA + if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) + { + dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec); + } + this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr); + + // Json, need to be moved to somewhere else +#ifdef __RAPIDJSON + // add Json of scf mag + Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization, + GlobalC::ucell.magnet.abs_magnetization, + this->pelec->f_en.etot * ModuleBase::Ry_to_eV, + this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, + drho, + duration); +#endif //__RAPIDJSON + + // notice for restart + if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) + { + std::cout << " SCF restart after this step!" << std::endl; + } } //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. @@ -689,13 +692,6 @@ void ESolver_KS::after_scf(const int istep) { this->pelec->print_eigenvalue(GlobalV::ofs_running); } - // #ifdef __RAPIDJSON - // // add Json of efermi energy converge - // Json::add_output_efermi_converge(this->pelec->eferm.ef * ModuleBase::Ry_to_eV, this->conv_esolver); - // // add nkstot,nkstot_ibz to output json - // int Jnkstot = this->pelec->klist->get_nkstot(); - // Json::add_nkstot(Jnkstot); - // #endif //__RAPIDJSON } //------------------------------------------------------------------------------ diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index 5cc6ccbd55..2f615dab1c 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -18,70 +18,63 @@ namespace ModuleESolver template class ESolver_KS : public ESolver_FP { - public: + public: + //! Constructor + ESolver_KS(); - //! Constructor - ESolver_KS(); + //! Deconstructor + virtual ~ESolver_KS(); - //! Deconstructor - virtual ~ESolver_KS(); + virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override; - double scf_thr; // scf density threshold + virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 - double scf_ene_thr; // scf energy threshold + virtual void runner(const int istep, UnitCell& cell) override; - double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver) + // calculate electron density from a specific Hamiltonian + virtual void hamilt2density(const int istep, const int iter, const double ethr); - int maxniter; // maximum iter steps for scf + // calculate electron states from a specific Hamiltonian + virtual void hamilt2estates(const double ethr) {}; - int niter; // iter steps actually used in scf + protected: + //! Something to do before SCF iterations. + virtual void before_scf(const int istep) {}; - int out_freq_elec; // frequency for output + //! Something to do before hamilt2density function in each iter loop. + virtual void iter_init(const int istep, const int iter); - virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override; + //! Something to do after hamilt2density function in each iter loop. + virtual void iter_finish(const int istep, int& iter); - virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 + //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. + virtual void after_scf(const int istep) override; - virtual void runner(const int istep, UnitCell& cell) override; + //! It should be replaced by a function in Hamilt Class + virtual void update_pot(const int istep, const int iter) {}; - // calculate electron density from a specific Hamiltonian - virtual void hamilt2density(const int istep, const int iter, const double ethr); + //! Hamiltonian + hamilt::Hamilt* p_hamilt = nullptr; - // calculate electron states from a specific Hamiltonian - virtual void hamilt2estates(const double ethr){}; + ModulePW::PW_Basis_K* pw_wfc = nullptr; - protected: - //! Something to do before SCF iterations. - virtual void before_scf(const int istep) {}; + Charge_Mixing* p_chgmix = nullptr; - //! Something to do before hamilt2density function in each iter loop. - virtual void iter_init(const int istep, const int iter) {}; + wavefunc wf; - //! Something to do after hamilt2density function in each iter loop. - virtual void iter_finish(const int istep, int& iter); + // wavefunction coefficients + psi::Psi* psi = nullptr; - //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. - virtual void after_scf(const int istep) override; - - //! It should be replaced by a function in Hamilt Class - virtual void update_pot(const int istep, const int iter) {}; - - protected: - //! Hamiltonian - hamilt::Hamilt* p_hamilt = nullptr; - - ModulePW::PW_Basis_K* pw_wfc = nullptr; - - Charge_Mixing* p_chgmix = nullptr; - - wavefunc wf; - - // wavefunction coefficients - psi::Psi* psi = nullptr; - - protected: - std::string basisname; // PW or LCAO - double esolver_KS_ne = 0.0; + std::string basisname; // PW or LCAO + double esolver_KS_ne = 0.0; + double iter_time; // the start time of scf iteration + double diag_ethr; // the threshold for diagonalization + double scf_thr; // scf density threshold + double scf_ene_thr; // scf energy threshold + double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver) + int maxniter; // maximum iter steps for scf + int niter; // iter steps actually used in scf + int out_freq_elec; // frequency for output }; } // end of namespace #endif diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index b12fda5916..5bd2dca2f4 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -519,6 +519,9 @@ void ESolver_KS_LCAO::iter_init(const int istep, const int iter) { ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init"); + // call iter_init() of ESolver_KS + ESolver_KS::iter_init(istep, iter); + if (iter == 1) { this->p_chgmix->init_mixing(); // init mixing diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 17ec60b6bb..6815ecd89c 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -305,6 +305,9 @@ void ESolver_KS_PW::before_scf(const int istep) template void ESolver_KS_PW::iter_init(const int istep, const int iter) { + // call iter_init() of ESolver_KS + ESolver_KS::iter_init(istep, iter); + if (iter == 1) { this->p_chgmix->init_mixing(); From 40602d5b5c204b235ae770908c82c36b6148e757 Mon Sep 17 00:00:00 2001 From: YuLiu98 Date: Thu, 7 Nov 2024 10:12:48 +0800 Subject: [PATCH 2/3] update esolver.h --- source/module_esolver/esolver_ks.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index 2f615dab1c..f249ec36b8 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -27,26 +27,26 @@ class ESolver_KS : public ESolver_FP virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override; - virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 - virtual void runner(const int istep, UnitCell& cell) override; - // calculate electron density from a specific Hamiltonian - virtual void hamilt2density(const int istep, const int iter, const double ethr); - - // calculate electron states from a specific Hamiltonian - virtual void hamilt2estates(const double ethr) {}; - protected: //! Something to do before SCF iterations. virtual void before_scf(const int istep) {}; + virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 + //! Something to do before hamilt2density function in each iter loop. virtual void iter_init(const int istep, const int iter); //! Something to do after hamilt2density function in each iter loop. virtual void iter_finish(const int istep, int& iter); + // calculate electron density from a specific Hamiltonian + virtual void hamilt2density(const int istep, const int iter, const double ethr); + + // calculate electron states from a specific Hamiltonian + virtual void hamilt2estates(const double ethr) {}; + //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. virtual void after_scf(const int istep) override; From d5c0bd230a513f26dca889c83f96032f0d47bb8f Mon Sep 17 00:00:00 2001 From: YuLiu98 Date: Thu, 7 Nov 2024 10:54:14 +0800 Subject: [PATCH 3/3] fix some errors --- source/module_esolver/esolver_ks.cpp | 10 +++------- source/module_esolver/esolver_ks.h | 9 +++++++++ source/module_esolver/esolver_ks_pw.cpp | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index e8a7479bbf..45b228ed6c 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -1,12 +1,5 @@ #include "esolver_ks.h" -#include -#include -#ifdef __MPI -#include -#else -#include -#endif #include "module_base/timer.h" #include "module_cell/cal_atoms_info.h" #include "module_io/json_output/init_info.h" @@ -15,6 +8,9 @@ #include "module_io/print_info.h" #include "module_io/write_istate_info.h" #include "module_parameter/parameter.h" + +#include +#include //--------------Temporary---------------- #include "module_base/global_variable.h" #include "module_hamilt_lcao/module_dftu/dftu.h" diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index 47ae8c6583..8b952a6e22 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -10,6 +10,11 @@ #include "module_io/cal_test.h" #include "module_psi/psi.h" +#ifdef __MPI +#include +#else +#include +#endif #include #include namespace ModuleESolver @@ -68,7 +73,11 @@ class ESolver_KS : public ESolver_FP std::string basisname; // PW or LCAO double esolver_KS_ne = 0.0; bool oscillate_esolver = false; // whether esolver is oscillated +#ifdef __MPI double iter_time; // the start time of scf iteration +#else + std::chrono::system_clock::time_point iter_time; +#endif double diag_ethr; // the threshold for diagonalization double scf_thr; // scf density threshold double scf_ene_thr; // scf energy threshold diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 6815ecd89c..91bc8c2618 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -306,7 +306,7 @@ template void ESolver_KS_PW::iter_init(const int istep, const int iter) { // call iter_init() of ESolver_KS - ESolver_KS::iter_init(istep, iter); + ESolver_KS::iter_init(istep, iter); if (iter == 1) {