Skip to content

Commit 59973ea

Browse files
authored
Make Esolver simpler (#6557)
* update small places in charge_mixing_residual.cpp * update charge_mixing_preconditioner * add timers and remove some PARAM.inp.nspin * fix problems * add timers * small fix * fix a potential memory leak * fix bug * add ctrl_output_pw files but cannot run now * add some interfaces in ctrl_output_pw * keep fixing bugs * successfully compile the codes * finally I understand Devicegit add ../source/source_base/module_device/device.h ../source/source_io/ctrl_output_pw.cpp ../source/source_io/ctrl_output_pw.h! * update function variables * one step further * move on * update codes, done! * fix bug * add setup_pot * fix bugs * update esolver_ks_pw, add setup_pot * update format of esolver_ks_pw.cpp * update setup_pot for GPU version * update esolver_ks_pw.cpp * add setup_estate_pw in source_estate * fix some bugs * small update * move teardown to deconstructor of ESolver * fix error in passing pointers * move teardown function to the correct place * fix bugs * update * add two functions in module_pwdft * add setup_pwrho file, and update esolver_fp * update esolver_fp * add teardown function in setup_pwrho * update esolver * move the after_all_runners() function of the base class to the end * add TITLE for after_all_runners * update esolver_fp, move teardown to desconstructor * Revert "update esolver_fp, move teardown to desconstructor" This reverts commit 8516477. * Revert "add TITLE for after_all_runners" This reverts commit 16ba5d2. * Revert "move the after_all_runners() function of the base class to the end" This reverts commit d672bca.
1 parent a9ac67f commit 59973ea

24 files changed

+839
-454
lines changed

source/Makefile.Objects

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ OBJS_ELECSTAT=elecstate.o\
238238
gatefield.o\
239239
potential_new.o\
240240
potential_types.o\
241+
pot_sep.o\
241242
pot_local.o\
242243
pot_local_paw.o\
243244
H_Hartree_pw.o\
@@ -248,7 +249,7 @@ OBJS_ELECSTAT=elecstate.o\
248249
cal_nelec_nband.o\
249250
read_pseudo.o\
250251
cal_wfc.o\
251-
pot_sep.o\
252+
setup_estate_pw.o\
252253

253254
OBJS_ELECSTAT_LCAO=elecstate_lcao.o\
254255
elecstate_lcao_cal_tau.o\
@@ -714,6 +715,8 @@ OBJS_SRCPW=H_Ewald_pw.o\
714715
charge_mixing_rho.o\
715716
charge_mixing_uspp.o\
716717
fp_energy.o\
718+
setup_pot.o\
719+
setup_pwrho.o\
717720
forces.o\
718721
forces_us.o\
719722
forces_nl.o\

source/source_esolver/esolver.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class ESolver
1717

1818
virtual ~ESolver()
1919
{
20+
//****************************************************
21+
// do not add any codes in this deconstructor funcion
22+
//****************************************************
2023
}
2124

2225
//! initialize the energy solver by using input parameters and cell modules

source/source_esolver/esolver_fp.cpp

Lines changed: 25 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
#include "esolver_fp.h"
22

3-
#include "source_base/global_variable.h"
43
#include "source_estate/cal_ux.h"
54
#include "source_estate/module_charge/symmetry_rho.h"
65
#include "source_estate/read_pseudo.h"
76
#include "source_hamilt/module_ewald/H_Ewald_pw.h"
87
#include "source_hamilt/module_vdw/vdw.h"
9-
#include "source_pw/module_pwdft/global.h"
108
#include "source_io/cif_io.h"
119
#include "source_io/cube_io.h" // use write_vdata_palgrid
1210
#include "source_io/json_output/init_info.h"
@@ -15,9 +13,11 @@
1513
#include "source_io/print_info.h"
1614
#include "source_io/rhog_io.h"
1715
#include "source_io/module_parameter/parameter.h"
18-
#include "source_cell/k_vector_utils.h"
1916
#include "source_io/ctrl_output_fp.h"
2017

18+
#include "source_pw/module_pwdft/setup_pwrho.h" // mohan 20251005
19+
#include "source_hamilt/module_xc/xc_functional.h" // mohan 20251005
20+
2121
namespace ModuleESolver
2222
{
2323

@@ -27,117 +27,36 @@ ESolver_FP::ESolver_FP()
2727

2828
ESolver_FP::~ESolver_FP()
2929
{
30-
if (pw_rho_flag == true)
31-
{
32-
delete this->pw_rho;
33-
this->pw_rho_flag = false;
34-
}
35-
if (PARAM.globalv.double_grid)
36-
{
37-
delete pw_rhod;
38-
}
39-
delete this->pelec;
30+
//****************************************************
31+
// do not add any codes in this deconstructor funcion
32+
//****************************************************
33+
// mohan add 20251005
34+
pw::teardown_pwrho(this->pw_rho_flag, PARAM.globalv.double_grid, this->pw_rho, this->pw_rhod);
35+
36+
delete this->pelec;
4037
}
4138

4239
void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
4340
{
4441
ModuleBase::TITLE("ESolver_FP", "before_all_runners");
45-
std::string fft_device = inp.device;
46-
std::string fft_precison = inp.precision;
47-
// LCAO basis doesn't support GPU acceleration on FFT currently
48-
if(inp.basis_type == "lcao")
49-
{
50-
fft_device = "cpu";
51-
}
52-
if ((inp.precision=="single") || (inp.precision=="mixing"))
53-
{
54-
fft_precison = "mixing";
55-
}
56-
else if (inp.precision=="double")
57-
{
58-
fft_precison = "double";
59-
}
60-
#if (not defined(__ENABLE_FLOAT_FFTW) and (defined(__CUDA) || defined(__RCOM)))
61-
if (fft_device == "gpu")
62-
{
63-
fft_precison = "double";
64-
}
65-
#endif
66-
pw_rho = new ModulePW::PW_Basis_Big(fft_device, fft_precison);
67-
pw_rho_flag = true;
68-
if (PARAM.globalv.double_grid)
69-
{
70-
pw_rhod = new ModulePW::PW_Basis_Big(fft_device, fft_precison);
71-
}
72-
else
73-
{
74-
pw_rhod = pw_rho;
75-
}
76-
pw_big = static_cast<ModulePW::PW_Basis_Big*>(pw_rhod);
77-
pw_big->setbxyz(inp.bx, inp.by, inp.bz);
78-
sf.set(pw_rhod, inp.nbspline);
7942

80-
//! 1) read pseudopotentials
43+
//! read pseudopotentials
8144
elecstate::read_pseudo(GlobalV::ofs_running, ucell);
8245

83-
//! 2) initialie the plane wave basis for rho
84-
#ifdef __MPI
85-
this->pw_rho->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
86-
#endif
87-
if (this->classname == "ESolver_OF" || inp.of_ml_gene_data == 1)
88-
{
89-
this->pw_rho->setfullpw(inp.of_full_pw, inp.of_full_pw_dim);
90-
}
46+
// setup pw_rho, pw_rhod, pw_big, sf, and read_pseudopotentials
47+
pw::setup_pwrho(ucell, PARAM.globalv.double_grid, this->pw_rho_flag,
48+
this->pw_rho, this->pw_rhod, this->pw_big,
49+
this->classname, inp);
9150

92-
if (inp.nx * inp.ny * inp.nz == 0)
93-
{
94-
this->pw_rho->initgrids(inp.ref_cell_factor * ucell.lat0, ucell.latvec, 4.0 * inp.ecutwfc);
95-
}
96-
else
97-
{
98-
this->pw_rho->initgrids(inp.ref_cell_factor * ucell.lat0, ucell.latvec, inp.nx, inp.ny, inp.nz);
99-
}
51+
// setup the structure factors
52+
this->sf.set(this->pw_rhod, inp.nbspline);
10053

101-
this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc);
102-
this->pw_rho->fft_bundle.initfftmode(inp.fft_mode);
103-
this->pw_rho->setuptransform();
104-
this->pw_rho->collect_local_pw();
105-
this->pw_rho->collect_uniqgg();
106-
107-
//! 3) initialize the double grid (for uspp) if necessary
108-
if ( PARAM.globalv.double_grid)
109-
{
110-
ModulePW::PW_Basis_Sup* pw_rhod_sup = static_cast<ModulePW::PW_Basis_Sup*>(pw_rhod);
111-
#ifdef __MPI
112-
this->pw_rhod->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
113-
#endif
114-
if (this->classname == "ESolver_OF")
115-
{
116-
this->pw_rhod->setfullpw(inp.of_full_pw, inp.of_full_pw_dim);
117-
}
118-
if (inp.ndx * inp.ndy * inp.ndz == 0)
119-
{
120-
this->pw_rhod->initgrids(inp.ref_cell_factor * ucell.lat0, ucell.latvec, inp.ecutrho);
121-
}
122-
else
123-
{
124-
this->pw_rhod->initgrids(inp.ref_cell_factor * ucell.lat0, ucell.latvec, inp.ndx, inp.ndy, inp.ndz);
125-
}
126-
this->pw_rhod->initparameters(false, inp.ecutrho);
127-
this->pw_rhod->fft_bundle.initfftmode(inp.fft_mode);
128-
pw_rhod_sup->setuptransform(this->pw_rho);
129-
this->pw_rhod->collect_local_pw();
130-
this->pw_rhod->collect_uniqgg();
131-
}
13254
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU.cif",
13355
ucell,
13456
"# Generated by ABACUS ModuleIO::CifParser",
13557
"data_?");
13658

137-
//! 4) print some information
138-
ModuleIO::print_rhofft(this->pw_rhod, this->pw_rho, this->pw_big, GlobalV::ofs_running);
139-
140-
//! 5) initialize the charge extrapolation method if necessary
59+
//! initialize the charge extrapolation method if necessary
14160
this->CE.Init_CE(inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap);
14261

14362
return;
@@ -148,16 +67,16 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
14867
{
14968
ModuleBase::TITLE("ESolver_FP", "after_scf");
15069

151-
// 1) output convergence information
70+
//! Output convergence information
15271
ModuleIO::output_convergence_after_scf(conv_esolver, this->pelec->f_en.etot);
15372

154-
// 2) write fermi energy
73+
//! Write Fermi energy
15574
ModuleIO::output_efermi(conv_esolver, this->pelec->eferm.ef);
15675

157-
// 3) update delta_rho for charge extrapolation
76+
//! Update delta_rho for charge extrapolation
15877
CE.update_delta_rho(ucell, &(this->chr), &(this->sf));
15978

160-
// 4) print out charge density, potential, elf, etc.
79+
//! print out charge density, potential, elf, etc.
16180
ModuleIO::ctrl_output_fp(ucell, this->pelec, this->pw_big, this->pw_rhod,
16281
this->chr, this->solvent, this->Pgrid, istep);
16382

@@ -202,9 +121,7 @@ void ESolver_FP::before_scf(UnitCell& ucell, const int istep)
202121
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
203122
}
204123

205-
//----------------------------------------------------------
206124
// charge extrapolation
207-
//----------------------------------------------------------
208125
if (ucell.ionic_position_updated)
209126
{
210127
this->CE.update_all_dis(ucell);
@@ -216,33 +133,23 @@ void ESolver_FP::before_scf(UnitCell& ucell, const int istep)
216133
GlobalV::ofs_warning);
217134
}
218135

219-
//----------------------------------------------------------
220136
//! calculate D2 or D3 vdW
221-
//----------------------------------------------------------
222137
auto vdw_solver = vdw::make_vdw(ucell, PARAM.inp, &(GlobalV::ofs_running));
223138
if (vdw_solver != nullptr)
224139
{
225140
this->pelec->f_en.evdw = vdw_solver->get_energy();
226141
}
227142

228-
//----------------------------------------------------------
229143
//! calculate ewald energy
230-
//----------------------------------------------------------
231144
if (!PARAM.inp.test_skip_ewald)
232145
{
233146
this->pelec->f_en.ewald_energy = H_Ewald_pw::compute_ewald(ucell, this->pw_rhod, this->sf.strucFac);
234147
}
235148

236-
//----------------------------------------------------------
237149
//! set direction of magnetism, used in non-collinear case
238-
//----------------------------------------------------------
239150
elecstate::cal_ux(ucell);
240151

241-
242-
243-
//----------------------------------------------------------
244152
//! output the initial charge density
245-
//----------------------------------------------------------
246153
const int nspin = PARAM.inp.nspin;
247154
if (PARAM.inp.out_chg[0] == 2)
248155
{
@@ -271,9 +178,7 @@ void ESolver_FP::before_scf(UnitCell& ucell, const int istep)
271178
}
272179
}
273180

274-
//----------------------------------------------------------
275181
//! output total local potential of the initial charge density
276-
//----------------------------------------------------------
277182
if (PARAM.inp.out_pot == 3)
278183
{
279184
for (int is = 0; is < nspin; is++)
@@ -308,7 +213,7 @@ void ESolver_FP::before_scf(UnitCell& ucell, const int istep)
308213

309214
void ESolver_FP::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver)
310215
{
311-
//! output charge density
216+
//! output charge density in G-space, or if available, kinetic energy density in G-space
312217
if (PARAM.inp.out_chg[0] != -1)
313218
{
314219
if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver)
@@ -352,10 +257,12 @@ void ESolver_FP::iter_finish(UnitCell& ucell, const int istep, int& iter, bool&
352257

353258
void ESolver_FP::after_all_runners(UnitCell& ucell)
354259
{
260+
// print out the final total energy
355261
GlobalV::ofs_running << "\n --------------------------------------------" << std::endl;
356262
GlobalV::ofs_running << std::setprecision(16);
357263
GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
358264
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
265+
359266
}
360267

361268
} // namespace ModuleESolver

source/source_esolver/esolver_fp.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ class ESolver_FP: public ESolver
4747

4848
virtual void iter_finish(UnitCell& ucell, const int istep, int& iter, bool &conv_esolver);
4949

50-
//! ------------------------------------------------------------------------------
5150
//! These pointers will be deleted in the free_pointers() function every ion step.
52-
//! ------------------------------------------------------------------------------
5351
elecstate::ElecState* pelec = nullptr; ///< Electronic states
5452

5553
//! K points in Brillouin zone
@@ -82,7 +80,7 @@ class ESolver_FP: public ESolver
8280
//! solvent model
8381
surchem solvent;
8482

85-
int pw_rho_flag = false; ///< flag for pw_rho, 0: not initialized, 1: initialized
83+
bool pw_rho_flag = false; ///< flag for pw_rho, 0: not initialized, 1: initialized
8684

8785
//! the start time of scf iteration
8886
#ifdef __MPI

source/source_esolver/esolver_ks.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "source_io/json_output/output_info.h"
2727

2828

29-
3029
namespace ModuleESolver
3130
{
3231

@@ -39,7 +38,10 @@ ESolver_KS<T, Device>::ESolver_KS()
3938
template <typename T, typename Device>
4039
ESolver_KS<T, Device>::~ESolver_KS()
4140
{
42-
delete this->psi;
41+
//****************************************************
42+
// do not add any codes in this deconstructor funcion
43+
//****************************************************
44+
delete this->psi;
4345
delete this->pw_wfc;
4446
delete this->p_hamilt;
4547
delete this->p_chgmix;

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ ESolver_KS_LCAO<TK, TR>::ESolver_KS_LCAO()
9797
template <typename TK, typename TR>
9898
ESolver_KS_LCAO<TK, TR>::~ESolver_KS_LCAO()
9999
{
100+
//****************************************************
101+
// do not add any codes in this deconstructor funcion
102+
//****************************************************
100103
}
101104

102105
template <typename TK, typename TR>

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ ESolver_KS_LCAO_TDDFT<TR, Device>::ESolver_KS_LCAO_TDDFT()
5757
template <typename TR, typename Device>
5858
ESolver_KS_LCAO_TDDFT<TR, Device>::~ESolver_KS_LCAO_TDDFT()
5959
{
60-
delete psi_laststep;
60+
//****************************************************
61+
// do not add any codes in this deconstructor funcion
62+
//****************************************************
63+
delete psi_laststep;
6164
if (Hk_laststep != nullptr)
6265
{
6366
for (int ik = 0; ik < this->kv.get_nks(); ++ik)

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ namespace ModuleESolver
5151
template <typename T>
5252
ESolver_KS_LIP<T>::~ESolver_KS_LIP()
5353
{
54+
//****************************************************
55+
// do not add any codes in this deconstructor funcion
56+
//****************************************************
5457
delete this->psi_local;
5558
// delete Hamilt
5659
this->deallocate_hamilt();

0 commit comments

Comments
 (0)