Skip to content

Commit 2649152

Browse files
authored
Refactor: refactor psi init & wfinit class (#5533)
* refactor psi init * remove useless code
1 parent 8b98a20 commit 2649152

File tree

6 files changed

+45
-68
lines changed

6 files changed

+45
-68
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,16 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
312312
// does is only to initialize for once...
313313
if (((PARAM.inp.init_wfc == "random") && (istep == 0)) || (PARAM.inp.init_wfc != "random"))
314314
{
315-
this->p_wf_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
315+
this->p_wf_init->initialize_psi(this->psi,
316+
this->kspw_psi,
317+
this->p_hamilt,
318+
GlobalV::ofs_running,
319+
this->already_initpsi);
320+
321+
if (this->already_initpsi == false)
322+
{
323+
this->already_initpsi = true;
324+
}
316325
}
317326
}
318327

@@ -359,27 +368,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
359368
}
360369
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
361370

362-
//---------------------------------------------------------------------------------------------------------------
363-
//---------------------------------for psi init guess!!!!--------------------------------------------------------
364-
//---------------------------------------------------------------------------------------------------------------
365-
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
366-
{
367-
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
368-
{
369-
//! Update Hamiltonian from other kpoint to the given one
370-
this->p_hamilt->updateHk(ik);
371-
372-
//! Fix the wavefunction to initialize at given kpoint
373-
this->kspw_psi->fix_k(ik);
374-
375-
/// for psi init guess!!!!
376-
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
377-
}
378-
}
379-
//---------------------------------------------------------------------------------------------------------------
380-
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
381-
//---------------------------------------------------------------------------------------------------------------
382-
383371
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
384372
PARAM.inp.calculation,
385373
PARAM.inp.basis_type,
@@ -400,8 +388,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
400388
GlobalV::NPROC_IN_POOL,
401389
skip_charge);
402390

403-
this->init_psi = true;
404-
405391
Symmetry_rho srho;
406392
for (int is = 0; is < PARAM.inp.nspin; is++)
407393
{

source/module_esolver/esolver_ks_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
6969

7070
psi::Psi<std::complex<double>, Device>* __kspw_psi = nullptr;
7171

72-
bool init_psi = false;
72+
bool already_initpsi = false;
7373

7474
using castmem_2d_d2h_op
7575
= base_device::memory::cast_memory_op<std::complex<double>, T, base_device::DEVICE_CPU, Device>;

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -190,32 +190,6 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
190190
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
191191
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
192192

193-
//---------------------------------------------------------------------------------------------------------------
194-
//---------------------------------for psi init guess!!!!--------------------------------------------------------
195-
//---------------------------------------------------------------------------------------------------------------
196-
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
197-
{
198-
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
199-
{
200-
//! Update Hamiltonian from other kpoint to the given one
201-
this->p_hamilt->updateHk(ik);
202-
203-
if (this->kspw_psi->get_nbands() > 0 && GlobalV::MY_STOGROUP == 0)
204-
{
205-
//! Fix the wavefunction to initialize at given kpoint
206-
this->kspw_psi->fix_k(ik);
207-
208-
/// for psi init guess!!!!
209-
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
210-
}
211-
212-
}
213-
}
214-
//---------------------------------------------------------------------------------------------------------------
215-
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
216-
//---------------------------------------------------------------------------------------------------------------
217-
218-
219193
// hsolver only exists in this function
220194
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(&this->kv,
221195
this->pw_wfc,
@@ -242,7 +216,6 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
242216
istep,
243217
iter,
244218
skip_charge);
245-
this->init_psi = true;
246219

247220
// set_diagethr need it
248221
this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne;

source/module_esolver/pw_init_after_vc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void ESolver_KS_PW<T, Device>::init_after_vc(const Input_para& inp, UnitCell& uc
8989
this->pw_wfc->collect_local_pw(inp.erf_ecut,
9090
inp.erf_height,
9191
inp.erf_sigma);
92-
this->init_psi = false;
92+
this->already_initpsi = false;
9393

9494
delete this->pelec;
9595
this->pelec

source/module_hamilt_pw/hamilt_pwdft/wfinit.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ template <typename T, typename Device>
152152
void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
153153
psi::Psi<T, Device>* kspw_psi,
154154
hamilt::Hamilt<T, Device>* p_hamilt,
155-
std::ofstream& ofs_running)
155+
std::ofstream& ofs_running,
156+
const bool is_already_initpsi)
156157
{
157158
ModuleBase::timer::tick("WFInit", "initialize_psi");
158159

@@ -254,20 +255,35 @@ void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
254255
}
255256
else
256257
{
257-
// if (PARAM.inp.basis_type == "pw")
258-
// {
259-
// for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
260-
// {
261-
// //! Update Hamiltonian from other kpoint to the given one
262-
// p_hamilt->updateHk(ik);
258+
//! note: is_already_initpsi will be false in init_after_vc when vc changes.
259+
if (PARAM.inp.basis_type == "pw" && is_already_initpsi == false)
260+
{
261+
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
262+
{
263+
//! Update Hamiltonian from other kpoint to the given one
264+
p_hamilt->updateHk(ik);
263265

264-
// //! Fix the wavefunction to initialize at given kpoint
265-
// kspw_psi->fix_k(ik);
266+
if (PARAM.inp.esolver_type == "sdft")
267+
{
268+
if (kspw_psi->get_nbands() > 0 && GlobalV::MY_STOGROUP == 0)
269+
{
270+
//! Fix the wavefunction to initialize at given kpoint
271+
kspw_psi->fix_k(ik);
266272

267-
// /// for psi init guess!!!!
268-
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *kspw_psi, this->pw_wfc, this->p_wf, p_hamilt);
269-
// }
270-
// }
273+
/// for psi init guess!!!!
274+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(kspw_psi), this->pw_wfc, this->p_wf, p_hamilt);
275+
}
276+
}
277+
else
278+
{
279+
//! Fix the wavefunction to initialize at given kpoint
280+
kspw_psi->fix_k(ik);
281+
282+
/// for psi init guess!!!!
283+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(kspw_psi), this->pw_wfc, this->p_wf, p_hamilt);
284+
}
285+
}
286+
}
271287
}
272288

273289
ModuleBase::timer::tick("WFInit", "initialize_psi");

source/module_hamilt_pw/hamilt_pwdft/wfinit.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@ class WFInit
4747
*
4848
* @param psi store the wavefunction
4949
* @param p_hamilt Hamiltonian operator
50-
* @param ofs_running output stream for running information
50+
* @param ofs_running output stream for running information
51+
* @param is_already_initpsi whether psi has been initialized
5152
*/
5253
void initialize_psi(Psi<std::complex<double>>* psi,
5354
psi::Psi<T, Device>* kspw_psi,
5455
hamilt::Hamilt<T, Device>* p_hamilt,
55-
std::ofstream& ofs_running);
56+
std::ofstream& ofs_running,
57+
const bool is_already_initpsi);
5658

5759
/**
5860
* @brief get the psi_initializer

0 commit comments

Comments
 (0)