Skip to content

Commit 7c68dfc

Browse files
authored
Refactor: Remove redundant Input_para from ESolver Class (#6370)
* Refactor: Replace PARAM.inp with inp in ESolver classes for consistency * Refactor: Replace local input parameters with PARAM.inp in ESolver classes for consistency * Refactor: Use PARAM.inp.scf_ene_thr in ESolver_KS_LCAO iter_finish method * Revert "Refactor: Use PARAM.inp.scf_ene_thr in ESolver_KS_LCAO iter_finish method" This reverts commit b1bd0fd. * Revert "Refactor: Replace local input parameters with PARAM.inp in ESolver classes for consistency" This reverts commit f4f81e3.
1 parent b8e9264 commit 7c68dfc

File tree

7 files changed

+74
-74
lines changed

7 files changed

+74
-74
lines changed

source/source_esolver/esolver_fp.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,18 @@ ESolver_FP::~ESolver_FP()
4747
void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
4848
{
4949
ModuleBase::TITLE("ESolver_FP", "before_all_runners");
50-
std::string fft_device = PARAM.inp.device;
51-
std::string fft_precison = PARAM.inp.precision;
50+
std::string fft_device = inp.device;
51+
std::string fft_precison = inp.precision;
5252
// LCAO basis doesn't support GPU acceleration on FFT currently
53-
if(PARAM.inp.basis_type == "lcao")
53+
if(inp.basis_type == "lcao")
5454
{
5555
fft_device = "cpu";
5656
}
57-
if ((PARAM.inp.precision=="single") || (PARAM.inp.precision=="mixing"))
57+
if ((inp.precision=="single") || (inp.precision=="mixing"))
5858
{
5959
fft_precison = "mixing";
6060
}
61-
else if (PARAM.inp.precision=="double")
61+
else if (inp.precision=="double")
6262
{
6363
fft_precison = "double";
6464
}
@@ -79,8 +79,8 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
7979
pw_rhod = pw_rho;
8080
}
8181
pw_big = static_cast<ModulePW::PW_Basis_Big*>(pw_rhod);
82-
pw_big->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz);
83-
sf.set(pw_rhod, PARAM.inp.nbspline);
82+
pw_big->setbxyz(inp.bx, inp.by, inp.bz);
83+
sf.set(pw_rhod, inp.nbspline);
8484

8585
//! 1) read pseudopotentials
8686
elecstate::read_pseudo(GlobalV::ofs_running, ucell);
@@ -89,7 +89,7 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
8989
#ifdef __MPI
9090
this->pw_rho->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
9191
#endif
92-
if (this->classname == "ESolver_OF" || PARAM.inp.of_ml_gene_data == 1)
92+
if (this->classname == "ESolver_OF" || inp.of_ml_gene_data == 1)
9393
{
9494
this->pw_rho->setfullpw(inp.of_full_pw, inp.of_full_pw_dim);
9595
}
@@ -143,7 +143,7 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
143143
ModuleIO::print_rhofft(this->pw_rhod, this->pw_rho, this->pw_big, GlobalV::ofs_running);
144144

145145
//! 5) initialize the charge extrapolation method if necessary
146-
this->CE.Init_CE(PARAM.inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap);
146+
this->CE.Init_CE(inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap);
147147

148148
return;
149149
}

source/source_esolver/esolver_ks.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,23 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
5757
classname = "ESolver_KS";
5858
basisname = "";
5959

60-
scf_thr = PARAM.inp.scf_thr;
61-
scf_ene_thr = PARAM.inp.scf_ene_thr;
62-
maxniter = PARAM.inp.scf_nmax;
60+
scf_thr = inp.scf_thr;
61+
scf_ene_thr = inp.scf_ene_thr;
62+
maxniter = inp.scf_nmax;
6363
niter = maxniter;
6464
drho = 0.0;
6565

66-
std::string fft_device = PARAM.inp.device;
66+
std::string fft_device = inp.device;
6767

6868
// Fast Fourier Transform
6969
// LCAO basis doesn't support GPU acceleration on FFT currently
70-
if(PARAM.inp.basis_type == "lcao")
70+
if(inp.basis_type == "lcao")
7171
{
7272
fft_device = "cpu";
7373
}
74-
std::string fft_precision = PARAM.inp.precision;
74+
std::string fft_precision = inp.precision;
7575
#ifdef __ENABLE_FLOAT_FFTW
76-
if (PARAM.inp.cal_cond && PARAM.inp.esolver_type == "sdft")
76+
if (inp.cal_cond && inp.esolver_type == "sdft")
7777
{
7878
fft_precision = "mixing";
7979
}
@@ -83,7 +83,7 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
8383
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
8484

8585
// should not use INPUT here, mohan 2024-05-12
86-
tmp->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz);
86+
tmp->setbxyz(inp.bx, inp.by, inp.bz);
8787

8888
///----------------------------------------------------------
8989
/// charge mixing
@@ -92,7 +92,7 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
9292
p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod);
9393

9494
// cell_factor
95-
this->ppcell.cell_factor = PARAM.inp.cell_factor;
95+
this->ppcell.cell_factor = inp.cell_factor;
9696

9797

9898
//! 3) it has been established that
@@ -103,16 +103,16 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
103103
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL");
104104

105105
//! 4) setup the charge mixing parameters
106-
p_chgmix->set_mixing(PARAM.inp.mixing_mode,
107-
PARAM.inp.mixing_beta,
108-
PARAM.inp.mixing_ndim,
109-
PARAM.inp.mixing_gg0,
110-
PARAM.inp.mixing_tau,
111-
PARAM.inp.mixing_beta_mag,
112-
PARAM.inp.mixing_gg0_mag,
113-
PARAM.inp.mixing_gg0_min,
114-
PARAM.inp.mixing_angle,
115-
PARAM.inp.mixing_dmr,
106+
p_chgmix->set_mixing(inp.mixing_mode,
107+
inp.mixing_beta,
108+
inp.mixing_ndim,
109+
inp.mixing_gg0,
110+
inp.mixing_tau,
111+
inp.mixing_beta_mag,
112+
inp.mixing_gg0_mag,
113+
inp.mixing_gg0_min,
114+
inp.mixing_angle,
115+
inp.mixing_dmr,
116116
ucell.omega,
117117
ucell.tpiba);
118118

@@ -127,7 +127,7 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
127127
}
128128

129129
//! 6) Setup the k points according to symmetry.
130-
this->kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
130+
this->kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
131131
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
132132

133133
//! 7) print information

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
139139
int ncol = 0;
140140
if (PARAM.globalv.gamma_only_local)
141141
{
142-
nsk = PARAM.inp.nspin;
142+
nsk = inp.nspin;
143143
ncol = this->pv.ncol_bands;
144-
if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "lapack"
145-
|| PARAM.inp.ks_solver == "pexsi" || PARAM.inp.ks_solver == "cusolver"
146-
|| PARAM.inp.ks_solver == "cusolvermp")
144+
if (inp.ks_solver == "genelpa" || inp.ks_solver == "elpa" || inp.ks_solver == "lapack"
145+
|| inp.ks_solver == "pexsi" || inp.ks_solver == "cusolver"
146+
|| inp.ks_solver == "cusolvermp")
147147
{
148148
ncol = this->pv.ncol;
149149
}
@@ -154,22 +154,22 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
154154
#ifdef __MPI
155155
ncol = this->pv.ncol_bands;
156156
#else
157-
ncol = PARAM.inp.nbands;
157+
ncol = inp.nbands;
158158
#endif
159159
}
160160
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, this->kv.ngk, true);
161161
}
162162

163163
// 5) read psi from file
164-
if (PARAM.inp.init_wfc == "file"&& PARAM.inp.esolver_type != "tddft")
164+
if (inp.init_wfc == "file" && inp.esolver_type != "tddft")
165165
{
166166
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir,
167167
this->pv,
168168
*(this->psi),
169169
this->pelec,
170170
this->pelec->klist->ik2iktot,
171171
this->pelec->klist->get_nkstot(),
172-
PARAM.inp.nspin))
172+
inp.nspin))
173173
{
174174
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "read electronic wave functions failed");
175175
}
@@ -178,16 +178,16 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
178178
// 6) initialize the density matrix
179179
// DensityMatrix is allocated here, DMK is also initialized here
180180
// DMR is not initialized here, it will be constructed in each before_scf
181-
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin);
181+
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin);
182182

183183
// 7) initialize exact exchange calculations
184184
#ifdef __EXX
185-
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax"
186-
|| PARAM.inp.calculation == "md")
185+
if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax"
186+
|| inp.calculation == "md")
187187
{
188188
if (GlobalC::exx_info.info_global.cal_exx)
189189
{
190-
if (PARAM.inp.init_wfc != "file")
190+
if (inp.init_wfc != "file")
191191
{ // if init_wfc==file, directly enter the EXX loop
192192
XC_Functional::set_xc_first_loop(ucell);
193193
}
@@ -208,7 +208,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
208208
#endif
209209

210210
// 8) initialize DFT+U
211-
if (PARAM.inp.dft_plus_u)
211+
if (inp.dft_plus_u)
212212
{
213213
auto* dftu = ModuleDFTU::DFTU::get_instance();
214214
dftu->init(ucell, &this->pv, this->kv.get_nks(), &orb_);
@@ -219,7 +219,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
219219
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL");
220220

221221
// 10) inititlize the charge density
222-
this->chr.allocate(PARAM.inp.nspin);
222+
this->chr.allocate(inp.nspin);
223223
this->pelec->omega = ucell.omega;
224224

225225
// 11) initialize the potential
@@ -238,13 +238,13 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
238238
// 12) initialize deepks
239239
#ifdef __MLALGO
240240
LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running);
241-
if (PARAM.inp.deepks_scf)
241+
if (inp.deepks_scf)
242242
{
243243
// load the DeePKS model from deep neural network
244-
DeePKS_domain::load_model(PARAM.inp.deepks_model, ld.model_deepks);
244+
DeePKS_domain::load_model(inp.deepks_model, ld.model_deepks);
245245
// read pdm from file for NSCF or SCF-restart, do it only once in whole calculation
246-
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
247-
PARAM.inp.deepks_equiv,
246+
DeePKS_domain::read_pdm((inp.init_chg == "file"),
247+
inp.deepks_equiv,
248248
ld.init_pdm,
249249
ucell.nat,
250250
orb_.Alpha[0].getTotal_nchi() * ucell.nat,
@@ -257,11 +257,11 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
257257

258258
// 13) set occupations
259259
// tddft does not need to set occupations in the first scf
260-
if (PARAM.inp.ocp && inp.esolver_type != "tddft")
260+
if (inp.ocp && inp.esolver_type != "tddft")
261261
{
262-
elecstate::fixed_weights(PARAM.inp.ocp_kb,
263-
PARAM.inp.nbands,
264-
PARAM.inp.nelec,
262+
elecstate::fixed_weights(inp.ocp_kb,
263+
inp.nbands,
264+
inp.nelec,
265265
this->pelec->klist,
266266
this->pelec->wg,
267267
this->pelec->skip_weights);
@@ -289,7 +289,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
289289
}
290290

291291
// 15) initialize rdmft, added by jghan
292-
if (PARAM.inp.rdmft == true)
292+
if (inp.rdmft == true)
293293
{
294294
rdmft_solver.init(this->GG,
295295
this->GK,
@@ -300,8 +300,8 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
300300
*(this->pelec),
301301
this->orb_,
302302
two_center_bundle_,
303-
PARAM.inp.dft_functional,
304-
PARAM.inp.rdmft_power_alpha);
303+
inp.dft_functional,
304+
inp.rdmft_power_alpha);
305305
}
306306

307307
ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners");

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ namespace ModuleESolver
9393
this->kv.ngk,
9494
true);
9595
#ifdef __EXX
96-
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
97-
|| PARAM.inp.calculation == "cell-relax"
98-
|| PARAM.inp.calculation == "md") {
96+
if (inp.calculation == "scf" || inp.calculation == "relax"
97+
|| inp.calculation == "cell-relax"
98+
|| inp.calculation == "md") {
9999
if (GlobalC::exx_info.info_global.cal_exx)
100100
{
101101
XC_Functional::set_xc_first_loop(ucell);

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
169169
this->pelec->omega = ucell.omega;
170170

171171
//! 3) inititlize the charge density.
172-
this->chr.allocate(PARAM.inp.nspin);
172+
this->chr.allocate(inp.nspin);
173173

174174
//! 4) initialize the potential.
175175
if (this->pelec->pot == nullptr)
@@ -194,9 +194,9 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
194194
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");
195195

196196
//! 7) Allocate and initialize psi
197-
this->p_psi_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
198-
PARAM.inp.ks_solver,
199-
PARAM.inp.basis_type,
197+
this->p_psi_init = new psi::PSIInit<T, Device>(inp.init_wfc,
198+
inp.ks_solver,
199+
inp.basis_type,
200200
GlobalV::MY_RANK,
201201
ucell,
202202
this->sf,
@@ -206,28 +206,28 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
206206

207207
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max);
208208

209-
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
209+
this->p_psi_init->prepare_init(inp.pw_seed);
210210

211-
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
211+
this->kspw_psi = inp.device == "gpu" || inp.precision == "single"
212212
? new psi::Psi<T, Device>(this->psi[0])
213213
: reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
214214

215215
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");
216216

217217
//! 8) setup occupations
218-
if (PARAM.inp.ocp)
218+
if (inp.ocp)
219219
{
220-
elecstate::fixed_weights(PARAM.inp.ocp_kb,
221-
PARAM.inp.nbands,
222-
PARAM.inp.nelec,
220+
elecstate::fixed_weights(inp.ocp_kb,
221+
inp.nbands,
222+
inp.nelec,
223223
this->pelec->klist,
224224
this->pelec->wg,
225225
this->pelec->skip_weights);
226226
}
227227

228228
// 9) initialize exx pw
229-
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax"
230-
|| PARAM.inp.calculation == "md")
229+
if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax"
230+
|| inp.calculation == "md")
231231
{
232232
if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true)
233233
{

source/source_esolver/esolver_of.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp)
8787
}
8888

8989
// Setup the k points according to symmetry.
90-
kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
90+
kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
9191
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
9292

9393
// print information
@@ -127,12 +127,12 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp)
127127

128128
// Initialize KEDF
129129
// Calculate electron numbers, which will be used to initialize WT KEDF
130-
this->nelec_ = new double[PARAM.inp.nspin];
131-
if (PARAM.inp.nspin == 1)
130+
this->nelec_ = new double[inp.nspin];
131+
if (inp.nspin == 1)
132132
{
133-
this->nelec_[0] = PARAM.inp.nelec;
133+
this->nelec_[0] = inp.nelec;
134134
}
135-
else if (PARAM.inp.nspin == 2)
135+
else if (inp.nspin == 2)
136136
{
137137
// in fact, nelec_spin will not be used anymore
138138
this->pelec->init_nelec_spin();

source/source_esolver/esolver_sdft_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input
7272
true);
7373
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T));
7474

75-
if (PARAM.inp.nbands > 0)
75+
if (inp.nbands > 0)
7676
{
7777
this->stowf.chiortho
7878
= new psi::Psi<T, Device>(this->kv.get_nks(),

0 commit comments

Comments
 (0)