diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index da4c4c2c3c..45b228ed6c 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -1,12 +1,5 @@ #include "esolver_ks.h" -#include -#include -#ifdef __MPI -#include -#else -#include -#endif #include "module_base/timer.h" #include "module_cell/cal_atoms_info.h" #include "module_io/json_output/init_info.h" @@ -15,6 +8,9 @@ #include "module_io/print_info.h" #include "module_io/write_istate_info.h" #include "module_parameter/parameter.h" + +#include +#include //--------------Temporary---------------- #include "module_base/global_variable.h" #include "module_hamilt_lcao/module_dftu/dftu.h" @@ -427,49 +423,11 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) this->niter = this->maxniter; // 4) SCF iterations - double diag_ethr = PARAM.inp.pw_diag_thr; + this->diag_ethr = PARAM.inp.pw_diag_thr; std::cout << " * * * * * *\n << Start SCF iteration." << std::endl; for (int iter = 1; iter <= this->maxniter; ++iter) { - // 5) write head - ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname); - -#ifdef __MPI - auto iterstart = MPI_Wtime(); -#else - auto iterstart = std::chrono::system_clock::now(); -#endif - - if (PARAM.inp.esolver_type == "ksdft") - { - diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type, - PARAM.inp.esolver_type, - PARAM.inp.calculation, - PARAM.inp.init_chg, - PARAM.inp.precision, - istep, - iter, - drho, - PARAM.inp.pw_diag_thr, - diag_ethr, - PARAM.inp.nelec); - } - else if (PARAM.inp.esolver_type == "sdft") - { - diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type, - PARAM.inp.esolver_type, - PARAM.inp.calculation, - PARAM.inp.init_chg, - istep, - iter, - drho, - PARAM.inp.pw_diag_thr, - diag_ethr, - PARAM.inp.nbands, - esolver_KS_ne); - } - // 6) initialization of SCF iterations this->iter_init(istep, iter); @@ -615,33 +573,6 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) // 10) finish scf iterations this->iter_finish(istep, iter); -#ifdef __MPI - double duration = (double)(MPI_Wtime() - iterstart); -#else - double duration - = (std::chrono::duration_cast(std::chrono::system_clock::now() - iterstart)) - .count() - / static_cast(1e6); -#endif - - // 11) get mtaGGA related parameters - double dkin = 0.0; // for meta-GGA - if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) - { - dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec); - } - this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr); - - // 12) Json, need to be moved to somewhere else -#ifdef __RAPIDJSON - // add Json of scf mag - Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization, - GlobalC::ucell.magnet.abs_magnetization, - this->pelec->f_en.etot * ModuleBase::Ry_to_eV, - this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, - drho, - duration); -#endif //__RAPIDJSON // 13) check convergence if (this->conv_esolver || this->oscillate_esolver) @@ -653,12 +584,6 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) } break; } - - // notice for restart - if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) - { - std::cout << " SCF restart after this step!" << std::endl; - } } // end scf iterations std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl; @@ -671,6 +596,47 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) return; }; +template +void ESolver_KS::iter_init(const int istep, const int iter) +{ + ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname); + +#ifdef __MPI + iter_time = MPI_Wtime(); +#else + iter_time = std::chrono::system_clock::now(); +#endif + + if (PARAM.inp.esolver_type == "ksdft") + { + diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + PARAM.inp.precision, + istep, + iter, + drho, + PARAM.inp.pw_diag_thr, + diag_ethr, + PARAM.inp.nelec); + } + else if (PARAM.inp.esolver_type == "sdft") + { + diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type, + PARAM.inp.esolver_type, + PARAM.inp.calculation, + PARAM.inp.init_chg, + istep, + iter, + drho, + PARAM.inp.pw_diag_thr, + diag_ethr, + PARAM.inp.nbands, + esolver_KS_ne); + } +} + template void ESolver_KS::iter_finish(const int istep, int& iter) { @@ -684,6 +650,39 @@ void ESolver_KS::iter_finish(const int istep, int& iter) } this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old; this->pelec->f_en.etot_old = this->pelec->f_en.etot; + +#ifdef __MPI + double duration = (double)(MPI_Wtime() - iter_time); +#else + double duration + = (std::chrono::duration_cast(std::chrono::system_clock::now() - iter_time)).count() + / static_cast(1e6); +#endif + + // get mtaGGA related parameters + double dkin = 0.0; // for meta-GGA + if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) + { + dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec); + } + this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr); + + // Json, need to be moved to somewhere else +#ifdef __RAPIDJSON + // add Json of scf mag + Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization, + GlobalC::ucell.magnet.abs_magnetization, + this->pelec->f_en.etot * ModuleBase::Ry_to_eV, + this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, + drho, + duration); +#endif //__RAPIDJSON + + // notice for restart + if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) + { + std::cout << " SCF restart after this step!" << std::endl; + } } //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. @@ -698,13 +697,6 @@ void ESolver_KS::after_scf(const int istep) { this->pelec->print_eigenvalue(GlobalV::ofs_running); } - // #ifdef __RAPIDJSON - // // add Json of efermi energy converge - // Json::add_output_efermi_converge(this->pelec->eferm.ef * ModuleBase::Ry_to_eV, this->conv_esolver); - // // add nkstot,nkstot_ibz to output json - // int Jnkstot = this->pelec->klist->get_nkstot(); - // Json::add_nkstot(Jnkstot); - // #endif //__RAPIDJSON } //------------------------------------------------------------------------------ diff --git a/source/module_esolver/esolver_ks.h b/source/module_esolver/esolver_ks.h index bcdb9abc94..8b952a6e22 100644 --- a/source/module_esolver/esolver_ks.h +++ b/source/module_esolver/esolver_ks.h @@ -10,79 +10,81 @@ #include "module_io/cal_test.h" #include "module_psi/psi.h" -#include +#ifdef __MPI +#include +#else +#include +#endif #include +#include namespace ModuleESolver { template class ESolver_KS : public ESolver_FP { - public: - - //! Constructor - ESolver_KS(); - - //! Deconstructor - virtual ~ESolver_KS(); - - double scf_thr; // scf density threshold + public: + //! Constructor + ESolver_KS(); - double scf_ene_thr; // scf energy threshold + //! Deconstructor + virtual ~ESolver_KS(); - double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver) + virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override; - int maxniter; // maximum iter steps for scf + virtual void runner(const int istep, UnitCell& cell) override; - int niter; // iter steps actually used in scf + protected: + //! Something to do before SCF iterations. + virtual void before_scf(const int istep) {}; - int out_freq_elec; // frequency for output + virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 - virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override; + //! Something to do before hamilt2density function in each iter loop. + virtual void iter_init(const int istep, const int iter); - virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09 + //! Something to do after hamilt2density function in each iter loop. + virtual void iter_finish(const int istep, int& iter); - virtual void runner(const int istep, UnitCell& cell) override; + // calculate electron density from a specific Hamiltonian + virtual void hamilt2density(const int istep, const int iter, const double ethr); - // calculate electron density from a specific Hamiltonian - virtual void hamilt2density(const int istep, const int iter, const double ethr); + // calculate electron states from a specific Hamiltonian + virtual void hamilt2estates(const double ethr) {}; - // calculate electron states from a specific Hamiltonian - virtual void hamilt2estates(const double ethr){}; + //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. + virtual void after_scf(const int istep) override; - protected: - //! Something to do before SCF iterations. - virtual void before_scf(const int istep) {}; + //! It should be replaced by a function in Hamilt Class + virtual void update_pot(const int istep, const int iter) {}; - //! Something to do before hamilt2density function in each iter loop. - virtual void iter_init(const int istep, const int iter) {}; + //! Hamiltonian + hamilt::Hamilt* p_hamilt = nullptr; - //! Something to do after hamilt2density function in each iter loop. - virtual void iter_finish(const int istep, int& iter); + ModulePW::PW_Basis_K* pw_wfc = nullptr; - //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. - virtual void after_scf(const int istep) override; + Charge_Mixing* p_chgmix = nullptr; - //! It should be replaced by a function in Hamilt Class - virtual void update_pot(const int istep, const int iter) {}; + wavefunc wf; - protected: - //! Hamiltonian - hamilt::Hamilt* p_hamilt = nullptr; + // wavefunction coefficients + psi::Psi* psi = nullptr; - ModulePW::PW_Basis_K* pw_wfc = nullptr; - - Charge_Mixing* p_chgmix = nullptr; - - wavefunc wf; - - // wavefunction coefficients - psi::Psi* psi = nullptr; - - protected: - std::string basisname; // PW or LCAO - double esolver_KS_ne = 0.0; - bool oscillate_esolver = false; // whether esolver is oscillated -}; -} // end of namespace + std::string basisname; // PW or LCAO + double esolver_KS_ne = 0.0; + bool oscillate_esolver = false; // whether esolver is oscillated +#ifdef __MPI + double iter_time; // the start time of scf iteration +#else + std::chrono::system_clock::time_point iter_time; +#endif + double diag_ethr; // the threshold for diagonalization + double scf_thr; // scf density threshold + double scf_ene_thr; // scf energy threshold + double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver) + int maxniter; // maximum iter steps for scf + int niter; // iter steps actually used in scf + int out_freq_elec; // frequency for output +}; +} // namespace ModuleESolver #endif diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index b12fda5916..5bd2dca2f4 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -519,6 +519,9 @@ void ESolver_KS_LCAO::iter_init(const int istep, const int iter) { ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init"); + // call iter_init() of ESolver_KS + ESolver_KS::iter_init(istep, iter); + if (iter == 1) { this->p_chgmix->init_mixing(); // init mixing diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 17ec60b6bb..91bc8c2618 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -305,6 +305,9 @@ void ESolver_KS_PW::before_scf(const int istep) template void ESolver_KS_PW::iter_init(const int istep, const int iter) { + // call iter_init() of ESolver_KS + ESolver_KS::iter_init(istep, iter); + if (iter == 1) { this->p_chgmix->init_mixing();