Skip to content

Commit b2a0c8a

Browse files
authored
Add an interface for psi in esolver_ks_pw (#6599)
* add psi interface for pw * update setup_psi * update pis * keep updating psi * fix things * update psi fix bugs * update psi * update psi * update, now can be compiled successfully * fix bug * update cmake in module_psi * update esolver_ks * fix error introduced by esolver_ks's psi, which should be esolver_ks_pw's new psi (stp.psi) * change function name
1 parent 4cfd8f5 commit b2a0c8a

File tree

11 files changed

+269
-188
lines changed

11 files changed

+269
-188
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
764764
of_stress_pw.o\
765765
symmetry_rho.o\
766766
symmetry_rhog.o\
767+
setup_psi.o\
767768
psi_init.o\
768769
elecond.o\
769770
sto_tool.o\

source/source_esolver/esolver_ks.cpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
334334
//----------------------------------------------------------------
335335
// 2) compute magnetization, only for LSDA(spin==2)
336336
//----------------------------------------------------------------
337-
ucell.magnet.compute_mag(ucell.omega,
338-
this->chr.nrxx,
339-
this->chr.nxyz,
340-
this->chr.rho,
337+
ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, this->chr.rho,
341338
this->pelec->nelec_spin.data());
342339

343340
//----------------------------------------------------------------
@@ -434,20 +431,15 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
434431
MPI_Bcast(this->chr.rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD);
435432
#endif
436433

437-
//----------------------------------------------------------------
438434
// 4) Update potentials (should be done every SF iter)
439-
//----------------------------------------------------------------
440-
// Hamilt should be used after it is constructed.
441-
// this->phamilt->update(conv_esolver);
442435
this->update_pot(ucell, istep, iter, conv_esolver);
443436

444-
//----------------------------------------------------------------
445437
// 5) calculate energies
446-
//----------------------------------------------------------------
447438
// 1 means Harris-Foulkes functional
448439
// 2 means Kohn-Sham functional
449440
this->pelec->cal_energies(1);
450441
this->pelec->cal_energies(2);
442+
451443
if (iter == 1)
452444
{
453445
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
@@ -456,7 +448,6 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
456448
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
457449

458450

459-
460451
//----------------------------------------------------------------
461452
// 6) time and meta-GGA
462453
//----------------------------------------------------------------
@@ -481,21 +472,15 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
481472

482473

483474
#ifdef __RAPIDJSON
484-
//----------------------------------------------------------------
485475
// 7) add Json of scf mag
486-
//----------------------------------------------------------------
487-
Json::add_output_scf_mag(ucell.magnet.tot_mag,
488-
ucell.magnet.abs_mag,
476+
Json::add_output_scf_mag(ucell.magnet.tot_mag, ucell.magnet.abs_mag,
489477
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
490478
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
491-
drho,
492-
duration);
479+
drho, duration);
493480
#endif //__RAPIDJSON
494481

495482

496-
//----------------------------------------------------------------
497483
// 7) SCF restart information
498-
//----------------------------------------------------------------
499484
if (PARAM.inp.mixing_restart > 0
500485
&& iter == this->p_chgmix->mixing_restart_step - 1
501486
&& iter != PARAM.inp.scf_nmax)
@@ -504,9 +489,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
504489
std::cout << " SCF restart after this step!" << std::endl;
505490
}
506491

507-
//----------------------------------------------------------------
508492
// 8) Iter finish
509-
//----------------------------------------------------------------
510493
ESolver_FP::iter_finish(ucell, istep, iter, conv_esolver);
511494
}
512495

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,17 @@ namespace ModuleESolver
8181
void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
8282
{
8383
ESolver_KS_PW<T>::before_scf(ucell, istep);
84-
this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
84+
this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
8585
}
8686

8787
template <typename T>
8888
void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
8989
{
9090
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
9191
delete this->psi_local;
92-
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
93-
this->p_psi_init->psi_initer->nbands_start(),
94-
this->psi->get_nbasis(),
92+
this->psi_local = new psi::Psi<T>(this->stp.psi_cpu->get_nk(),
93+
this->stp.p_psi_init->psi_initer->nbands_start(),
94+
this->stp.psi_cpu->get_nbasis(),
9595
this->kv.ngk,
9696
true);
9797
#ifdef __EXX
@@ -105,13 +105,12 @@ namespace ModuleESolver
105105
ucell.symm,
106106
&this->kv,
107107
this->psi_local,
108-
this->kspw_psi,
108+
this->stp.psi_t,
109109
this->pw_wfc,
110110
this->pw_rho,
111111
this->sf,
112112
&ucell,
113113
this->pelec));
114-
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
115114
}
116115
}
117116
#endif
@@ -147,7 +146,8 @@ namespace ModuleESolver
147146
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
148147

149148
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
150-
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
149+
hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec,
150+
*this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
151151

152152
// add exx
153153
#ifdef __EXX
@@ -244,7 +244,7 @@ namespace ModuleESolver
244244
ModuleIO::write_Vxc(PARAM.inp.nspin,
245245
PARAM.globalv.nlocal,
246246
GlobalV::DRANK,
247-
*this->kspw_psi,
247+
*this->stp.psi_t,
248248
ucell,
249249
this->sf,
250250
this->solvent,

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 27 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,9 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
4949
// delete Hamilt
5050
this->deallocate_hamilt();
5151

52-
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
53-
{
54-
delete this->kspw_psi;
55-
}
56-
if (PARAM.inp.precision == "single")
57-
{
58-
delete this->__kspw_psi;
59-
}
52+
// mohan add 2025-10-12
53+
this->stp.clean();
6054

61-
delete this->psi;
62-
delete this->p_psi_init;
6355
}
6456

6557
template <typename T, typename Device>
@@ -89,18 +81,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
8981
this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho,
9082
this->pw_rhod, this->pw_big, this->solvent, inp);
9183

92-
//! Allocate and initialize psi
93-
this->p_psi_init = new psi::PSIInit<T, Device>(inp.init_wfc,
94-
inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell,
95-
this->sf, this->kv, this->ppcell, *this->pw_wfc);
96-
97-
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max);
98-
99-
this->p_psi_init->prepare_init(inp.pw_seed);
100-
101-
this->kspw_psi = inp.device == "gpu" || inp.precision == "single"
102-
? new psi::Psi<T, Device>(this->psi[0])
103-
: reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
84+
this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp);
10485

10586
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");
10687

@@ -142,7 +123,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
142123

143124
this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);
144125

145-
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
126+
this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
146127
}
147128

148129
//! Init Hamiltonian (cell changed)
@@ -156,14 +137,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
156137
//! Setup potentials (local, non-local, sc, +U, DFT-1/2)
157138
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
158139
this->chr, this->locpp, this->ppcell, this->vsep_cell,
159-
this->kspw_psi, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);
140+
this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);
160141

161-
//! Initialize wave functions
162-
if (!this->already_initpsi)
163-
{
164-
this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
165-
this->already_initpsi = true;
166-
}
142+
143+
this->stp.init(this->p_hamilt);
167144

168145
//! Exx calculations
169146
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
@@ -173,7 +150,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
173150
{
174151
auto hamilt_pw = reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
175152
hamilt_pw->set_exx_helper(exx_helper);
176-
exx_helper.set_psi(kspw_psi);
153+
exx_helper.set_psi(this->stp.psi_t);
177154
}
178155
}
179156

@@ -202,7 +179,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const
202179
// new DFT+U method will calculate energy when evaluating the Hamiltonian
203180
if (dftu->omc != 2)
204181
{
205-
dftu->cal_occ_pw(iter, this->kspw_psi, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
182+
dftu->cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
206183
}
207184
dftu->output(ucell);
208185
}
@@ -271,7 +248,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
271248
PARAM.inp.use_k_continuity);
272249

273250
hsolver_pw_obj.solve(this->p_hamilt,
274-
this->kspw_psi[0],
251+
this->stp.psi_t[0],
275252
this->pelec,
276253
this->pelec->ekb.c,
277254
GlobalV::RANK_IN_POOL,
@@ -316,7 +293,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
316293
// Related to EXX
317294
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter)
318295
{
319-
this->pelec->set_exx(exx_helper.cal_exx_energy(kspw_psi));
296+
this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t));
320297
}
321298

322299
// deband is calculated from "output" charge density
@@ -347,12 +324,12 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
347324
double dexx = 0.0;
348325
if (PARAM.inp.exx_thr_type == "energy")
349326
{
350-
dexx = exx_helper.cal_exx_energy(this->kspw_psi);
327+
dexx = exx_helper.cal_exx_energy(this->stp.psi_t);
351328
}
352-
exx_helper.set_psi(this->kspw_psi);
329+
exx_helper.set_psi(this->stp.psi_t);
353330
if (PARAM.inp.exx_thr_type == "energy")
354331
{
355-
dexx -= exx_helper.cal_exx_energy(this->kspw_psi);
332+
dexx -= exx_helper.cal_exx_energy(this->stp.psi_t);
356333
// std::cout << "dexx = " << dexx << std::endl;
357334
}
358335
bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr;
@@ -373,7 +350,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
373350
}
374351
else
375352
{
376-
exx_helper.set_psi(this->kspw_psi);
353+
exx_helper.set_psi(this->stp.psi_t);
377354
}
378355
}
379356

@@ -394,7 +371,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
394371
}
395372

396373
// the output quantities
397-
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi,
374+
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu,
398375
this->kv, this->pw_wfc, PARAM.inp);
399376
}
400377

@@ -409,24 +386,16 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
409386
// sunliang 2025-04-10
410387
if (PARAM.inp.out_elf[0] > 0)
411388
{
412-
this->ESolver_KS<T, Device>::psi = new psi::Psi<T>(this->psi[0]);
389+
this->ESolver_KS<T, Device>::psi = new psi::Psi<T>(this->stp.psi_cpu[0]);
413390
}
414391

415392
// Call 'after_scf' of ESolver_KS
416393
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);
417394

418-
// Transfer data from GPU to CPU in pw basis
419-
if (this->device == base_device::GpuDevice)
420-
{
421-
castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
422-
this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(),
423-
this->psi[0].size());
424-
}
425-
426395
// Output quantities
427396
ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
428-
this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi,
429-
this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp);
397+
this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
398+
this->ctx, this->device, this->Pgrid, PARAM.inp);
430399

431400
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
432401
}
@@ -442,39 +411,25 @@ void ESolver_KS_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix& fo
442411
{
443412
Forces<double, Device> ff(ucell.nat);
444413

445-
if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
446-
{
447-
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
448-
}
449-
450-
// Refresh __kspw_psi
451-
this->__kspw_psi = PARAM.inp.precision == "single"
452-
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
453-
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
414+
// mohan add 2025-10-12
415+
this->stp.update_psi_d();
454416

455417
// Calculate forces
456418
ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm,
457419
&this->sf, this->solvent, &this->locpp, &this->ppcell,
458-
&this->kv, this->pw_wfc, this->__kspw_psi);
420+
&this->kv, this->pw_wfc, this->stp.psi_d);
459421
}
460422

461423
template <typename T, typename Device>
462424
void ESolver_KS_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
463425
{
464426
Stress_PW<double, Device> ss(this->pelec);
465427

466-
if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
467-
{
468-
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
469-
}
470-
471-
// Refresh __kspw_psi
472-
this->__kspw_psi = PARAM.inp.precision == "single"
473-
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
474-
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
428+
// mohan add 2025-10-12
429+
this->stp.update_psi_d();
475430

476431
ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod,
477-
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi);
432+
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d);
478433

479434
// external stress
480435
double unit_transform = 0.0;
@@ -492,9 +447,8 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
492447
ESolver_KS<T, Device>::after_all_runners(ucell);
493448

494449
ModuleIO::ctrl_runner_pw<T, Device>(ucell, this->pelec, this->pw_wfc,
495-
this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi,
496-
this->kspw_psi, this->__kspw_psi, this->sf,
497-
this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp);
450+
this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp,
451+
this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp);
498452

499453
elecstate::teardown_estate_pw<T, Device>(this->pelec, this->vsep_cell);
500454

0 commit comments

Comments
 (0)