From 42c6cf2a9248ba723fa662e976f53598042ae5b8 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Wed, 9 Jul 2025 20:41:07 +0800 Subject: [PATCH 1/5] Refactor: Replace PARAM.inp with inp in ESolver classes for consistency --- source/source_esolver/esolver_fp.cpp | 18 ++++---- source/source_esolver/esolver_ks.cpp | 40 ++++++++--------- source/source_esolver/esolver_ks_lcao.cpp | 48 ++++++++++----------- source/source_esolver/esolver_ks_lcaopw.cpp | 6 +-- source/source_esolver/esolver_ks_pw.cpp | 24 +++++------ source/source_esolver/esolver_of.cpp | 10 ++--- source/source_esolver/esolver_sdft_pw.cpp | 2 +- 7 files changed, 74 insertions(+), 74 deletions(-) diff --git a/source/source_esolver/esolver_fp.cpp b/source/source_esolver/esolver_fp.cpp index 4d4b522bfe..97680ba8d0 100644 --- a/source/source_esolver/esolver_fp.cpp +++ b/source/source_esolver/esolver_fp.cpp @@ -47,18 +47,18 @@ ESolver_FP::~ESolver_FP() void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp) { ModuleBase::TITLE("ESolver_FP", "before_all_runners"); - std::string fft_device = PARAM.inp.device; - std::string fft_precison = PARAM.inp.precision; + std::string fft_device = inp.device; + std::string fft_precison = inp.precision; // LCAO basis doesn't support GPU acceleration on FFT currently - if(PARAM.inp.basis_type == "lcao") + if(inp.basis_type == "lcao") { fft_device = "cpu"; } - if ((PARAM.inp.precision=="single") || (PARAM.inp.precision=="mixing")) + if ((inp.precision=="single") || (inp.precision=="mixing")) { fft_precison = "mixing"; } - else if (PARAM.inp.precision=="double") + else if (inp.precision=="double") { fft_precison = "double"; } @@ -79,8 +79,8 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp) pw_rhod = pw_rho; } pw_big = static_cast(pw_rhod); - pw_big->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz); - sf.set(pw_rhod, PARAM.inp.nbspline); + pw_big->setbxyz(inp.bx, inp.by, inp.bz); + sf.set(pw_rhod, inp.nbspline); //! 1) read pseudopotentials elecstate::read_pseudo(GlobalV::ofs_running, ucell); @@ -89,7 +89,7 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp) #ifdef __MPI this->pw_rho->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD); #endif - if (this->classname == "ESolver_OF" || PARAM.inp.of_ml_gene_data == 1) + if (this->classname == "ESolver_OF" || inp.of_ml_gene_data == 1) { this->pw_rho->setfullpw(inp.of_full_pw, inp.of_full_pw_dim); } @@ -143,7 +143,7 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp) ModuleIO::print_rhofft(this->pw_rhod, this->pw_rho, this->pw_big, GlobalV::ofs_running); //! 5) initialize the charge extrapolation method if necessary - this->CE.Init_CE(PARAM.inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap); + this->CE.Init_CE(inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap); return; } diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 000f13f43b..43e22f1888 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -57,23 +57,23 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para classname = "ESolver_KS"; basisname = ""; - scf_thr = PARAM.inp.scf_thr; - scf_ene_thr = PARAM.inp.scf_ene_thr; - maxniter = PARAM.inp.scf_nmax; + scf_thr = inp.scf_thr; + scf_ene_thr = inp.scf_ene_thr; + maxniter = inp.scf_nmax; niter = maxniter; drho = 0.0; - std::string fft_device = PARAM.inp.device; + std::string fft_device = inp.device; // Fast Fourier Transform // LCAO basis doesn't support GPU acceleration on FFT currently - if(PARAM.inp.basis_type == "lcao") + if(inp.basis_type == "lcao") { fft_device = "cpu"; } - std::string fft_precision = PARAM.inp.precision; + std::string fft_precision = inp.precision; #ifdef __ENABLE_FLOAT_FFTW - if (PARAM.inp.cal_cond && PARAM.inp.esolver_type == "sdft") + if (inp.cal_cond && inp.esolver_type == "sdft") { fft_precision = "mixing"; } @@ -83,7 +83,7 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para ModulePW::PW_Basis_K_Big* tmp = static_cast(pw_wfc); // should not use INPUT here, mohan 2024-05-12 - tmp->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz); + tmp->setbxyz(inp.bx, inp.by, inp.bz); ///---------------------------------------------------------- /// charge mixing @@ -92,7 +92,7 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod); // cell_factor - this->ppcell.cell_factor = PARAM.inp.cell_factor; + this->ppcell.cell_factor = inp.cell_factor; //! 3) it has been established that @@ -103,16 +103,16 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL"); //! 4) setup the charge mixing parameters - p_chgmix->set_mixing(PARAM.inp.mixing_mode, - PARAM.inp.mixing_beta, - PARAM.inp.mixing_ndim, - PARAM.inp.mixing_gg0, - PARAM.inp.mixing_tau, - PARAM.inp.mixing_beta_mag, - PARAM.inp.mixing_gg0_mag, - PARAM.inp.mixing_gg0_min, - PARAM.inp.mixing_angle, - PARAM.inp.mixing_dmr, + p_chgmix->set_mixing(inp.mixing_mode, + inp.mixing_beta, + inp.mixing_ndim, + inp.mixing_gg0, + inp.mixing_tau, + inp.mixing_beta_mag, + inp.mixing_gg0_mag, + inp.mixing_gg0_min, + inp.mixing_angle, + inp.mixing_dmr, ucell.omega, ucell.tpiba); @@ -127,7 +127,7 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para } //! 6) Setup the k points according to symmetry. - this->kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); + this->kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS"); //! 7) print information diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 4bccc580b4..e9d3e83090 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -139,11 +139,11 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa int ncol = 0; if (PARAM.globalv.gamma_only_local) { - nsk = PARAM.inp.nspin; + nsk = inp.nspin; ncol = this->pv.ncol_bands; - if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "lapack" - || PARAM.inp.ks_solver == "pexsi" || PARAM.inp.ks_solver == "cusolver" - || PARAM.inp.ks_solver == "cusolvermp") + if (inp.ks_solver == "genelpa" || inp.ks_solver == "elpa" || inp.ks_solver == "lapack" + || inp.ks_solver == "pexsi" || inp.ks_solver == "cusolver" + || inp.ks_solver == "cusolvermp") { ncol = this->pv.ncol; } @@ -154,14 +154,14 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa #ifdef __MPI ncol = this->pv.ncol_bands; #else - ncol = PARAM.inp.nbands; + ncol = inp.nbands; #endif } this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // 5) read psi from file - if (PARAM.inp.init_wfc == "file") + if (inp.init_wfc == "file") { if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, this->pv, @@ -169,7 +169,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa this->pelec, this->pelec->klist->ik2iktot, this->pelec->klist->get_nkstot(), - PARAM.inp.nspin)) + inp.nspin)) { ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "read electronic wave functions failed"); } @@ -178,16 +178,16 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // 6) initialize the density matrix // DensityMatrix is allocated here, DMK is also initialized here // DMR is not initialized here, it will be constructed in each before_scf - dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin); + dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin); // 7) initialize exact exchange calculations #ifdef __EXX - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax" - || PARAM.inp.calculation == "md") + if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" + || inp.calculation == "md") { if (GlobalC::exx_info.info_global.cal_exx) { - if (PARAM.inp.init_wfc != "file") + if (inp.init_wfc != "file") { // if init_wfc==file, directly enter the EXX loop XC_Functional::set_xc_first_loop(ucell); } @@ -208,7 +208,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa #endif // 8) initialize DFT+U - if (PARAM.inp.dft_plus_u) + if (inp.dft_plus_u) { auto* dftu = ModuleDFTU::DFTU::get_instance(); dftu->init(ucell, &this->pv, this->kv.get_nks(), &orb_); @@ -219,7 +219,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); // 10) inititlize the charge density - this->chr.allocate(PARAM.inp.nspin); + this->chr.allocate(inp.nspin); this->pelec->omega = ucell.omega; // 11) initialize the potential @@ -238,13 +238,13 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // 12) initialize deepks #ifdef __MLALGO LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running); - if (PARAM.inp.deepks_scf) + if (inp.deepks_scf) { // load the DeePKS model from deep neural network - DeePKS_domain::load_model(PARAM.inp.deepks_model, ld.model_deepks); + DeePKS_domain::load_model(inp.deepks_model, ld.model_deepks); // read pdm from file for NSCF or SCF-restart, do it only once in whole calculation - DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"), - PARAM.inp.deepks_equiv, + DeePKS_domain::read_pdm((inp.init_chg == "file"), + inp.deepks_equiv, ld.init_pdm, ucell.nat, orb_.Alpha[0].getTotal_nchi() * ucell.nat, @@ -257,11 +257,11 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // 13) set occupations // tddft does not need to set occupations in the first scf - if (PARAM.inp.ocp && inp.esolver_type != "tddft") + if (inp.ocp && inp.esolver_type != "tddft") { - elecstate::fixed_weights(PARAM.inp.ocp_kb, - PARAM.inp.nbands, - PARAM.inp.nelec, + elecstate::fixed_weights(inp.ocp_kb, + inp.nbands, + inp.nelec, this->pelec->klist, this->pelec->wg, this->pelec->skip_weights); @@ -289,7 +289,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 15) initialize rdmft, added by jghan - if (PARAM.inp.rdmft == true) + if (inp.rdmft == true) { rdmft_solver.init(this->GG, this->GK, @@ -300,8 +300,8 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa *(this->pelec), this->orb_, two_center_bundle_, - PARAM.inp.dft_functional, - PARAM.inp.rdmft_power_alpha); + inp.dft_functional, + inp.rdmft_power_alpha); } ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners"); diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 4f1690d2f2..0de6f10a2b 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -93,9 +93,9 @@ namespace ModuleESolver this->kv.ngk, true); #ifdef __EXX - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" - || PARAM.inp.calculation == "cell-relax" - || PARAM.inp.calculation == "md") { + if (inp.calculation == "scf" || inp.calculation == "relax" + || inp.calculation == "cell-relax" + || inp.calculation == "md") { if (GlobalC::exx_info.info_global.cal_exx) { XC_Functional::set_xc_first_loop(ucell); diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 760a597d1c..acc752d0cf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -169,7 +169,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->pelec->omega = ucell.omega; //! 3) inititlize the charge density. - this->chr.allocate(PARAM.inp.nspin); + this->chr.allocate(inp.nspin); //! 4) initialize the potential. if (this->pelec->pot == nullptr) @@ -194,9 +194,9 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); //! 7) Allocate and initialize psi - this->p_psi_init = new psi::PSIInit(PARAM.inp.init_wfc, - PARAM.inp.ks_solver, - PARAM.inp.basis_type, + this->p_psi_init = new psi::PSIInit(inp.init_wfc, + inp.ks_solver, + inp.basis_type, GlobalV::MY_RANK, ucell, this->sf, @@ -206,28 +206,28 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); - this->p_psi_init->prepare_init(PARAM.inp.pw_seed); + this->p_psi_init->prepare_init(inp.pw_seed); - this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" + this->kspw_psi = inp.device == "gpu" || inp.precision == "single" ? new psi::Psi(this->psi[0]) : reinterpret_cast*>(this->psi); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); //! 8) setup occupations - if (PARAM.inp.ocp) + if (inp.ocp) { - elecstate::fixed_weights(PARAM.inp.ocp_kb, - PARAM.inp.nbands, - PARAM.inp.nelec, + elecstate::fixed_weights(inp.ocp_kb, + inp.nbands, + inp.nelec, this->pelec->klist, this->pelec->wg, this->pelec->skip_weights); } // 9) initialize exx pw - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax" - || PARAM.inp.calculation == "md") + if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" + || inp.calculation == "md") { if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true) { diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index cc0ed32e0b..1ece91eb62 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -87,7 +87,7 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp) } // Setup the k points according to symmetry. - kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); + kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS"); // print information @@ -127,12 +127,12 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp) // Initialize KEDF // Calculate electron numbers, which will be used to initialize WT KEDF - this->nelec_ = new double[PARAM.inp.nspin]; - if (PARAM.inp.nspin == 1) + this->nelec_ = new double[inp.nspin]; + if (inp.nspin == 1) { - this->nelec_[0] = PARAM.inp.nelec; + this->nelec_[0] = inp.nelec; } - else if (PARAM.inp.nspin == 2) + else if (inp.nspin == 2) { // in fact, nelec_spin will not be used anymore this->pelec->init_nelec_spin(); diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index fdfaf1700e..42fde3ebe5 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -72,7 +72,7 @@ void ESolver_SDFT_PW::before_all_runners(UnitCell& ucell, const Input true); ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T)); - if (PARAM.inp.nbands > 0) + if (inp.nbands > 0) { this->stowf.chiortho = new psi::Psi(this->kv.get_nks(), From f4f81e3902315e4446a55094ee93ff771eb68855 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Wed, 9 Jul 2025 20:59:00 +0800 Subject: [PATCH 2/5] Refactor: Replace local input parameters with PARAM.inp in ESolver classes for consistency --- source/source_esolver/esolver_ks.cpp | 15 +++++-------- source/source_esolver/esolver_ks.h | 3 --- source/source_esolver/esolver_of.cpp | 22 +++++++------------ source/source_esolver/esolver_of.h | 8 ------- .../source_esolver/esolver_of_interface.cpp | 14 ++++++------ 5 files changed, 21 insertions(+), 41 deletions(-) diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 43e22f1888..55c4274b95 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -57,10 +57,7 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para classname = "ESolver_KS"; basisname = ""; - scf_thr = inp.scf_thr; - scf_ene_thr = inp.scf_ene_thr; - maxniter = inp.scf_nmax; - niter = maxniter; + niter = inp.scf_nmax; drho = 0.0; std::string fft_device = inp.device; @@ -235,9 +232,9 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) // 2) SCF iterations //---------------------------------------------------------------- bool conv_esolver = false; - this->niter = this->maxniter; + this->niter = PARAM.inp.scf_nmax; this->diag_ethr = PARAM.inp.pw_diag_thr; - for (int iter = 1; iter <= this->maxniter; ++iter) + for (int iter = 1; iter <= PARAM.inp.scf_nmax; ++iter) { //---------------------------------------------------------------- // 3) initialization of SCF iterations @@ -398,10 +395,10 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } #endif - conv_esolver = (drho < this->scf_thr && not_restart_step && is_U_converged); + conv_esolver = (drho < PARAM.inp.scf_thr && not_restart_step && is_U_converged); // add energy threshold for SCF convergence - if (this->scf_ene_thr > 0.0) + if (PARAM.inp.scf_ene_thr > 0.0) { // calculate energy of output charge density this->update_pot(ucell, istep, iter, conv_esolver); @@ -415,7 +412,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i { // update the convergence flag conv_esolver - = (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < this->scf_ene_thr); + = (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < PARAM.inp.scf_ene_thr); } } diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index ef26d986c6..3d940328a2 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -76,11 +76,8 @@ class ESolver_KS : public ESolver_FP std::string basisname; //! esolver_ks_lcao.cpp double esolver_KS_ne = 0.0; //! number of electrons double diag_ethr; //! the threshold for diagonalization - double scf_thr; //! scf density threshold - double scf_ene_thr; //! scf energy threshold double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver) double hsolver_error; //! the error of HSolver - int maxniter; //! maximum iter steps for scf int niter; //! iter steps actually used in scf bool oscillate_esolver = false; // whether esolver is oscillated }; diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index 1ece91eb62..3914eb2957 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -60,12 +60,6 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp) ESolver_FP::before_all_runners(ucell, inp); // save necessary parameters - this->of_kinetic_ = inp.of_kinetic; - this->of_method_ = inp.of_method; - this->of_conv_ = inp.of_conv; - this->of_tole_ = inp.of_tole; - this->of_tolp_ = inp.of_tolp; - this->max_iter_ = inp.scf_nmax; this->dV_ = ucell.omega / this->pw_rho->nxyz; this->bound_cal_potential_ = std::bind(&ESolver_OF::cal_potential, this, std::placeholders::_1, std::placeholders::_2, std::ref(ucell)); @@ -422,7 +416,7 @@ void ESolver_OF::update_rho() } /** - * @brief Check convergence, return ture if converge or iter >= max_iter_, + * @brief Check convergence, return ture if converge or iter >= PARAM.inp.scf_nmax, * and print the necessary information * * @return exit or not @@ -434,7 +428,7 @@ bool ESolver_OF::check_exit(bool& conv_esolver) bool potHold = false; // if normdLdphi nearly remains unchanged bool energyConv = false; - if (this->normdLdphi_ < this->of_tolp_) + if (this->normdLdphi_ < PARAM.inp.of_tolp) { potConv = true; } @@ -444,23 +438,23 @@ bool ESolver_OF::check_exit(bool& conv_esolver) potHold = true; } - if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < this->of_tole_ - && std::abs(this->energy_current_ - this->energy_llast_) < this->of_tole_) + if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < PARAM.inp.of_tole + && std::abs(this->energy_current_ - this->energy_llast_) < PARAM.inp.of_tole) { energyConv = true; } - conv_esolver = (this->of_conv_ == "energy" && energyConv) || (this->of_conv_ == "potential" && potConv) - || (this->of_conv_ == "both" && potConv && energyConv); + conv_esolver = (PARAM.inp.of_conv == "energy" && energyConv) || (PARAM.inp.of_conv == "potential" && potConv) + || (PARAM.inp.of_conv == "both" && potConv && energyConv); this->print_info(conv_esolver); - if (conv_esolver || this->iter_ >= this->max_iter_) + if (conv_esolver || this->iter_ >= PARAM.inp.scf_nmax) { return true; } // ============ temporary solution of potential convergence =========== - else if (this->of_conv_ == "potential" && potHold) + else if (PARAM.inp.of_conv == "potential" && potHold) { GlobalV::ofs_warning << "ESolver_OF WARNING: " << "The convergence of potential has not been reached, but the norm of potential nearly " diff --git a/source/source_esolver/esolver_of.h b/source/source_esolver/esolver_of.h index 508a18b3db..332eb30139 100644 --- a/source/source_esolver/esolver_of.h +++ b/source/source_esolver/esolver_of.h @@ -38,14 +38,6 @@ class ESolver_OF : public ESolver_FP ModuleBase::Opt_DCsrch* opt_dcsrch_ = nullptr; ModuleBase::Opt_CG* opt_cg_mag_ = nullptr; // for spin2 case, under testing - // ----------------- necessary parameters from INPUT ------------ - std::string of_kinetic_ = "wt"; // Kinetic energy functional, such as TF, VW, WT - std::string of_method_ = "tn"; // optimization method, include cg1, cg2, tn (default), bfgs - std::string of_conv_ = "energy"; // select the convergence criterion, potential, energy (default), or both - double of_tole_ = 2e-6; // tolerance of the energy change (in Ry) for determining the convergence, default=2e-6 Ry - double of_tolp_ = 1e-5; // tolerance of potential for determining the convergence, default=1e-5 in a.u. - int max_iter_ = 50; // scf_nmax - // ------------------ parameters from other module -------------- double dV_ = 0; // volume of one grid point in real space double* nelec_ = nullptr; // number of electrons with each spin diff --git a/source/source_esolver/esolver_of_interface.cpp b/source/source_esolver/esolver_of_interface.cpp index 7051308a00..48755aecf0 100644 --- a/source/source_esolver/esolver_of_interface.cpp +++ b/source/source_esolver/esolver_of_interface.cpp @@ -16,7 +16,7 @@ void ESolver_OF::init_opt() this->opt_dcsrch_ = new ModuleBase::Opt_DCsrch(); } - if (this->of_method_ == "tn") + if (PARAM.inp.of_method == "tn") { if (this->opt_tn_ == nullptr) { @@ -25,7 +25,7 @@ void ESolver_OF::init_opt() this->opt_tn_->allocate(this->pw_rho->nrxx); this->opt_tn_->set_para(this->dV_); } - else if (this->of_method_ == "cg1" || this->of_method_ == "cg2") + else if (PARAM.inp.of_method == "cg1" || PARAM.inp.of_method == "cg2") { if (this->opt_cg_ == nullptr) { @@ -35,7 +35,7 @@ void ESolver_OF::init_opt() this->opt_cg_->set_para(this->dV_); this->opt_dcsrch_->set_paras(1e-4, 1e-2); } - else if (this->of_method_ == "bfgs") + else if (PARAM.inp.of_method == "bfgs") { ModuleBase::WARNING_QUIT("esolver_of", "BFGS is not supported now."); return; @@ -62,7 +62,7 @@ void ESolver_OF::get_direction(UnitCell& ucell) { for (int is = 0; is < PARAM.inp.nspin; ++is) { - if (this->of_method_ == "tn") + if (PARAM.inp.of_method == "tn") { this->tn_spin_flag_ = is; opt_tn_->next_direct(this->pphi_[is], @@ -72,15 +72,15 @@ void ESolver_OF::get_direction(UnitCell& ucell) this, &ESolver_OF::cal_potential_wrapper); } - else if (this->of_method_ == "cg1") + else if (PARAM.inp.of_method == "cg1") { opt_cg_->next_direct(this->pdLdphi_[is], 1, this->pdirect_[is]); } - else if (this->of_method_ == "cg2") + else if (PARAM.inp.of_method == "cg2") { opt_cg_->next_direct(this->pdLdphi_[is], 2, this->pdirect_[is]); } - else if (this->of_method_ == "bfgs") + else if (PARAM.inp.of_method == "bfgs") { return; } From b1bd0fda46630cd0acddc392e8ca8f7a22918c4f Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Wed, 9 Jul 2025 21:15:35 +0800 Subject: [PATCH 3/5] Refactor: Use PARAM.inp.scf_ene_thr in ESolver_KS_LCAO iter_finish method --- source/source_esolver/esolver_ks_lcao.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index e9d3e83090..4280ff9f40 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -853,7 +853,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& *this->p_hamilt, *this->pelec, *this->p_chgmix, - this->scf_ene_thr, + PARAM.inp.scf_ene_thr, iter, istep, conv_esolver) @@ -862,7 +862,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& *this->p_hamilt, *this->pelec, *this->p_chgmix, - this->scf_ene_thr, + PARAM.inp.scf_ene_thr, iter, istep, conv_esolver); From ba989f1249bdc85aea8237176afe7626a7dfd153 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Thu, 10 Jul 2025 11:52:40 +0800 Subject: [PATCH 4/5] Revert "Refactor: Use PARAM.inp.scf_ene_thr in ESolver_KS_LCAO iter_finish method" This reverts commit b1bd0fda46630cd0acddc392e8ca8f7a22918c4f. --- source/source_esolver/esolver_ks_lcao.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 4280ff9f40..e9d3e83090 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -853,7 +853,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& *this->p_hamilt, *this->pelec, *this->p_chgmix, - PARAM.inp.scf_ene_thr, + this->scf_ene_thr, iter, istep, conv_esolver) @@ -862,7 +862,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& *this->p_hamilt, *this->pelec, *this->p_chgmix, - PARAM.inp.scf_ene_thr, + this->scf_ene_thr, iter, istep, conv_esolver); From 758516a9fbb27c719b901cc1c17c5d521c7fdbe7 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Thu, 10 Jul 2025 11:52:58 +0800 Subject: [PATCH 5/5] Revert "Refactor: Replace local input parameters with PARAM.inp in ESolver classes for consistency" This reverts commit f4f81e3902315e4446a55094ee93ff771eb68855. --- source/source_esolver/esolver_ks.cpp | 15 ++++++++----- source/source_esolver/esolver_ks.h | 3 +++ source/source_esolver/esolver_of.cpp | 22 ++++++++++++------- source/source_esolver/esolver_of.h | 8 +++++++ .../source_esolver/esolver_of_interface.cpp | 14 ++++++------ 5 files changed, 41 insertions(+), 21 deletions(-) diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 55c4274b95..43e22f1888 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -57,7 +57,10 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para classname = "ESolver_KS"; basisname = ""; - niter = inp.scf_nmax; + scf_thr = inp.scf_thr; + scf_ene_thr = inp.scf_ene_thr; + maxniter = inp.scf_nmax; + niter = maxniter; drho = 0.0; std::string fft_device = inp.device; @@ -232,9 +235,9 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) // 2) SCF iterations //---------------------------------------------------------------- bool conv_esolver = false; - this->niter = PARAM.inp.scf_nmax; + this->niter = this->maxniter; this->diag_ethr = PARAM.inp.pw_diag_thr; - for (int iter = 1; iter <= PARAM.inp.scf_nmax; ++iter) + for (int iter = 1; iter <= this->maxniter; ++iter) { //---------------------------------------------------------------- // 3) initialization of SCF iterations @@ -395,10 +398,10 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } #endif - conv_esolver = (drho < PARAM.inp.scf_thr && not_restart_step && is_U_converged); + conv_esolver = (drho < this->scf_thr && not_restart_step && is_U_converged); // add energy threshold for SCF convergence - if (PARAM.inp.scf_ene_thr > 0.0) + if (this->scf_ene_thr > 0.0) { // calculate energy of output charge density this->update_pot(ucell, istep, iter, conv_esolver); @@ -412,7 +415,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i { // update the convergence flag conv_esolver - = (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < PARAM.inp.scf_ene_thr); + = (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < this->scf_ene_thr); } } diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 3d940328a2..ef26d986c6 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -76,8 +76,11 @@ class ESolver_KS : public ESolver_FP std::string basisname; //! esolver_ks_lcao.cpp double esolver_KS_ne = 0.0; //! number of electrons double diag_ethr; //! the threshold for diagonalization + double scf_thr; //! scf density threshold + double scf_ene_thr; //! scf energy threshold double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver) double hsolver_error; //! the error of HSolver + int maxniter; //! maximum iter steps for scf int niter; //! iter steps actually used in scf bool oscillate_esolver = false; // whether esolver is oscillated }; diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index 3914eb2957..1ece91eb62 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -60,6 +60,12 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp) ESolver_FP::before_all_runners(ucell, inp); // save necessary parameters + this->of_kinetic_ = inp.of_kinetic; + this->of_method_ = inp.of_method; + this->of_conv_ = inp.of_conv; + this->of_tole_ = inp.of_tole; + this->of_tolp_ = inp.of_tolp; + this->max_iter_ = inp.scf_nmax; this->dV_ = ucell.omega / this->pw_rho->nxyz; this->bound_cal_potential_ = std::bind(&ESolver_OF::cal_potential, this, std::placeholders::_1, std::placeholders::_2, std::ref(ucell)); @@ -416,7 +422,7 @@ void ESolver_OF::update_rho() } /** - * @brief Check convergence, return ture if converge or iter >= PARAM.inp.scf_nmax, + * @brief Check convergence, return ture if converge or iter >= max_iter_, * and print the necessary information * * @return exit or not @@ -428,7 +434,7 @@ bool ESolver_OF::check_exit(bool& conv_esolver) bool potHold = false; // if normdLdphi nearly remains unchanged bool energyConv = false; - if (this->normdLdphi_ < PARAM.inp.of_tolp) + if (this->normdLdphi_ < this->of_tolp_) { potConv = true; } @@ -438,23 +444,23 @@ bool ESolver_OF::check_exit(bool& conv_esolver) potHold = true; } - if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < PARAM.inp.of_tole - && std::abs(this->energy_current_ - this->energy_llast_) < PARAM.inp.of_tole) + if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < this->of_tole_ + && std::abs(this->energy_current_ - this->energy_llast_) < this->of_tole_) { energyConv = true; } - conv_esolver = (PARAM.inp.of_conv == "energy" && energyConv) || (PARAM.inp.of_conv == "potential" && potConv) - || (PARAM.inp.of_conv == "both" && potConv && energyConv); + conv_esolver = (this->of_conv_ == "energy" && energyConv) || (this->of_conv_ == "potential" && potConv) + || (this->of_conv_ == "both" && potConv && energyConv); this->print_info(conv_esolver); - if (conv_esolver || this->iter_ >= PARAM.inp.scf_nmax) + if (conv_esolver || this->iter_ >= this->max_iter_) { return true; } // ============ temporary solution of potential convergence =========== - else if (PARAM.inp.of_conv == "potential" && potHold) + else if (this->of_conv_ == "potential" && potHold) { GlobalV::ofs_warning << "ESolver_OF WARNING: " << "The convergence of potential has not been reached, but the norm of potential nearly " diff --git a/source/source_esolver/esolver_of.h b/source/source_esolver/esolver_of.h index 332eb30139..508a18b3db 100644 --- a/source/source_esolver/esolver_of.h +++ b/source/source_esolver/esolver_of.h @@ -38,6 +38,14 @@ class ESolver_OF : public ESolver_FP ModuleBase::Opt_DCsrch* opt_dcsrch_ = nullptr; ModuleBase::Opt_CG* opt_cg_mag_ = nullptr; // for spin2 case, under testing + // ----------------- necessary parameters from INPUT ------------ + std::string of_kinetic_ = "wt"; // Kinetic energy functional, such as TF, VW, WT + std::string of_method_ = "tn"; // optimization method, include cg1, cg2, tn (default), bfgs + std::string of_conv_ = "energy"; // select the convergence criterion, potential, energy (default), or both + double of_tole_ = 2e-6; // tolerance of the energy change (in Ry) for determining the convergence, default=2e-6 Ry + double of_tolp_ = 1e-5; // tolerance of potential for determining the convergence, default=1e-5 in a.u. + int max_iter_ = 50; // scf_nmax + // ------------------ parameters from other module -------------- double dV_ = 0; // volume of one grid point in real space double* nelec_ = nullptr; // number of electrons with each spin diff --git a/source/source_esolver/esolver_of_interface.cpp b/source/source_esolver/esolver_of_interface.cpp index 48755aecf0..7051308a00 100644 --- a/source/source_esolver/esolver_of_interface.cpp +++ b/source/source_esolver/esolver_of_interface.cpp @@ -16,7 +16,7 @@ void ESolver_OF::init_opt() this->opt_dcsrch_ = new ModuleBase::Opt_DCsrch(); } - if (PARAM.inp.of_method == "tn") + if (this->of_method_ == "tn") { if (this->opt_tn_ == nullptr) { @@ -25,7 +25,7 @@ void ESolver_OF::init_opt() this->opt_tn_->allocate(this->pw_rho->nrxx); this->opt_tn_->set_para(this->dV_); } - else if (PARAM.inp.of_method == "cg1" || PARAM.inp.of_method == "cg2") + else if (this->of_method_ == "cg1" || this->of_method_ == "cg2") { if (this->opt_cg_ == nullptr) { @@ -35,7 +35,7 @@ void ESolver_OF::init_opt() this->opt_cg_->set_para(this->dV_); this->opt_dcsrch_->set_paras(1e-4, 1e-2); } - else if (PARAM.inp.of_method == "bfgs") + else if (this->of_method_ == "bfgs") { ModuleBase::WARNING_QUIT("esolver_of", "BFGS is not supported now."); return; @@ -62,7 +62,7 @@ void ESolver_OF::get_direction(UnitCell& ucell) { for (int is = 0; is < PARAM.inp.nspin; ++is) { - if (PARAM.inp.of_method == "tn") + if (this->of_method_ == "tn") { this->tn_spin_flag_ = is; opt_tn_->next_direct(this->pphi_[is], @@ -72,15 +72,15 @@ void ESolver_OF::get_direction(UnitCell& ucell) this, &ESolver_OF::cal_potential_wrapper); } - else if (PARAM.inp.of_method == "cg1") + else if (this->of_method_ == "cg1") { opt_cg_->next_direct(this->pdLdphi_[is], 1, this->pdirect_[is]); } - else if (PARAM.inp.of_method == "cg2") + else if (this->of_method_ == "cg2") { opt_cg_->next_direct(this->pdLdphi_[is], 2, this->pdirect_[is]); } - else if (PARAM.inp.of_method == "bfgs") + else if (this->of_method_ == "bfgs") { return; }