Skip to content

Commit 30c9b28

Browse files
authored
Refactor: refactor the relevant code for psi initializer (#5474)
* Refactor the relevant code for psi initializer * Centralize all psi initializes code into WFInit Class * fix bug * fix bug * fix bug * fix 185-sdft bug * fix compiler bug * refactor hsolver
1 parent da1baf5 commit 30c9b28

File tree

10 files changed

+440
-374
lines changed

10 files changed

+440
-374
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,28 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
359359
}
360360
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
361361

362+
//---------------------------------------------------------------------------------------------------------------
363+
//---------------------------------for psi init guess!!!!--------------------------------------------------------
364+
//---------------------------------------------------------------------------------------------------------------
365+
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
366+
{
367+
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
368+
{
369+
//! Update Hamiltonian from other kpoint to the given one
370+
this->p_hamilt->updateHk(ik);
371+
372+
//! Fix the wavefunction to initialize at given kpoint
373+
this->kspw_psi->fix_k(ik);
374+
375+
/// for psi init guess!!!!
376+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
377+
}
378+
}
379+
//---------------------------------------------------------------------------------------------------------------
380+
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
381+
//---------------------------------------------------------------------------------------------------------------
382+
362383
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
363-
&this->wf,
364384
PARAM.inp.calculation,
365385
PARAM.inp.basis_type,
366386
PARAM.inp.ks_solver,
@@ -370,8 +390,7 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
370390
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
371391
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
372392
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
373-
hsolver::DiagoIterAssist<T, Device>::need_subspace,
374-
this->init_psi);
393+
hsolver::DiagoIterAssist<T, Device>::need_subspace);
375394

376395
hsolver_pw_obj.solve(this->p_hamilt,
377396
this->kspw_psi[0],

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,35 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
190190
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
191191
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
192192

193+
//---------------------------------------------------------------------------------------------------------------
194+
//---------------------------------for psi init guess!!!!--------------------------------------------------------
195+
//---------------------------------------------------------------------------------------------------------------
196+
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
197+
{
198+
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
199+
{
200+
//! Update Hamiltonian from other kpoint to the given one
201+
this->p_hamilt->updateHk(ik);
202+
203+
if (this->kspw_psi->get_nbands() > 0 && GlobalV::MY_STOGROUP == 0)
204+
{
205+
//! Fix the wavefunction to initialize at given kpoint
206+
this->kspw_psi->fix_k(ik);
207+
208+
/// for psi init guess!!!!
209+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
210+
}
211+
212+
}
213+
}
214+
//---------------------------------------------------------------------------------------------------------------
215+
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
216+
//---------------------------------------------------------------------------------------------------------------
217+
218+
193219
// hsolver only exists in this function
194220
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(&this->kv,
195221
this->pw_wfc,
196-
&this->wf,
197222
this->stowf,
198223
this->stoche,
199224
this->p_hamilt_sto,
@@ -206,8 +231,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
206231
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
207232
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
208233
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
209-
hsolver::DiagoIterAssist<T, Device>::need_subspace,
210-
this->init_psi);
234+
hsolver::DiagoIterAssist<T, Device>::need_subspace);
211235

212236
hsolver_pw_sdft_obj.solve(this->p_hamilt,
213237
this->kspw_psi[0],

source/module_hamilt_pw/hamilt_pwdft/wfinit.cpp

Lines changed: 102 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_base/timer.h"
55
#include "module_base/tool_quit.h"
66
#include "module_hsolver/diago_iter_assist.h"
7+
#include "module_parameter/parameter.h"
78
#include "module_psi/psi_initializer_atomic.h"
89
#include "module_psi/psi_initializer_atomic_random.h"
910
#include "module_psi/psi_initializer_nao.h"
@@ -38,9 +39,10 @@ void WFInit<T, Device>::prepare_init(Structure_Factor* p_sf,
3839
#endif
3940
pseudopot_cell_vnl* p_ppcell)
4041
{
41-
if (!this->use_psiinitializer) {
42+
if (!this->use_psiinitializer)
43+
{
4244
return;
43-
}
45+
}
4446
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
4547
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
4648
ModuleBase::timer::tick("WFInit", "prepare_init");
@@ -136,7 +138,9 @@ void WFInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
136138
template <typename T, typename Device>
137139
void WFInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf)
138140
{
139-
if (this->use_psiinitializer) {} // do not need to do anything because the interpolate table is unchanged
141+
if (this->use_psiinitializer)
142+
{
143+
} // do not need to do anything because the interpolate table is unchanged
140144
else // old initialization method, used in EXX calculation
141145
{
142146
this->p_wf->init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
@@ -150,95 +154,121 @@ void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
150154
hamilt::Hamilt<T, Device>* p_hamilt,
151155
std::ofstream& ofs_running)
152156
{
153-
if (!this->use_psiinitializer) { return; }
154157
ModuleBase::timer::tick("WFInit", "initialize_psi");
155-
// if psig is not allocated before, allocate it
156-
if (!this->psi_init->psig_use_count()) { this->psi_init->allocate(/*psig_only=*/true); }
157-
158-
// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
159-
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
160-
for (int ik = 0; ik < this->pw_wfc->nks; ik++)
158+
159+
if (PARAM.inp.psi_initializer)
161160
{
162-
//! Fix the wavefunction to initialize at given kpoint
163-
psi->fix_k(ik);
161+
// if psig is not allocated before, allocate it
162+
if (!this->psi_init->psig_use_count())
163+
{
164+
this->psi_init->allocate(/*psig_only=*/true);
165+
}
164166

165-
//! Update Hamiltonian from other kpoint to the given one
166-
p_hamilt->updateHk(ik);
167+
// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
168+
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
169+
for (int ik = 0; ik < this->pw_wfc->nks; ik++)
170+
{
171+
//! Fix the wavefunction to initialize at given kpoint
172+
psi->fix_k(ik);
167173

168-
//! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
169-
//! and G is wavevector of the peroiodic part of the Bloch function
170-
this->psi_init->proj_ao_onkG(ik);
174+
//! Update Hamiltonian from other kpoint to the given one
175+
p_hamilt->updateHk(ik);
171176

172-
//! psi_initializer manages memory of psig with shared pointer,
173-
//! its access to use is shared here via weak pointer
174-
//! therefore once the psi_initializer is destructed, psig will be destructed, too
175-
//! this way, we can avoid memory leak and undefined behavior
176-
std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
177+
//! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
178+
//! and G is wavevector of the peroiodic part of the Bloch function
179+
this->psi_init->proj_ao_onkG(ik);
177180

178-
if (psig.expired())
179-
{
180-
ModuleBase::WARNING_QUIT("WFInit::initialize_psi", "psig lifetime is expired");
181-
}
181+
//! psi_initializer manages memory of psig with shared pointer,
182+
//! its access to use is shared here via weak pointer
183+
//! therefore once the psi_initializer is destructed, psig will be destructed, too
184+
//! this way, we can avoid memory leak and undefined behavior
185+
std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
182186

183-
//! to use psig, we need to lock it to get a shared pointer version,
184-
//! then switch kpoint of psig to the given one
185-
auto psig_ = psig.lock();
186-
// CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
187-
// so we can only allocate memory for one kpoint with the maximal number of pw
188-
// over all kpoints, then the memory space will be always enough. Then for each
189-
// kpoint, the psig is calculated in an overwrite manner.
190-
const int ik_psig = (psig_->get_nk() == 1) ? 0 : ik;
191-
psig_->fix_k(ik_psig);
192-
193-
std::vector<typename GetTypeReal<T>::type> etatom(psig_->get_nbands(), 0.0);
194-
195-
// then adjust dimension from psig to psi
196-
// either by matrix-multiplication or by copying-discarding
197-
if (this->psi_init->method() != "random")
198-
{
199-
// lcaoinpw and pw share the same esolver. In the future, we will have different esolver
200-
if (((this->ks_solver == "cg") || (this->ks_solver == "lapack")) && (this->basis_type == "pw"))
187+
if (psig.expired())
201188
{
202-
// the following function is only run serially, to be improved
203-
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
204-
psig_->get_pointer(),
205-
psig_->get_nbands(),
206-
psig_->get_nbasis(),
207-
*(kspw_psi),
208-
etatom.data());
209-
continue;
189+
ModuleBase::WARNING_QUIT("WFInit::initialize_psi", "psig lifetime is expired");
210190
}
211-
else if ((this->ks_solver == "lapack") && (this->basis_type == "lcao_in_pw"))
191+
192+
//! to use psig, we need to lock it to get a shared pointer version,
193+
//! then switch kpoint of psig to the given one
194+
auto psig_ = psig.lock();
195+
// CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
196+
// so we can only allocate memory for one kpoint with the maximal number of pw
197+
// over all kpoints, then the memory space will be always enough. Then for each
198+
// kpoint, the psig is calculated in an overwrite manner.
199+
const int ik_psig = (psig_->get_nk() == 1) ? 0 : ik;
200+
psig_->fix_k(ik_psig);
201+
202+
std::vector<typename GetTypeReal<T>::type> etatom(psig_->get_nbands(), 0.0);
203+
204+
// then adjust dimension from psig to psi
205+
// either by matrix-multiplication or by copying-discarding
206+
if (this->psi_init->method() != "random")
212207
{
213-
if (ik == 0)
208+
// lcaoinpw and pw share the same esolver. In the future, we will have different esolver
209+
if (((this->ks_solver == "cg") || (this->ks_solver == "lapack")) && (this->basis_type == "pw"))
210+
{
211+
// the following function is only run serially, to be improved
212+
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
213+
psig_->get_pointer(),
214+
psig_->get_nbands(),
215+
psig_->get_nbasis(),
216+
*(kspw_psi),
217+
etatom.data());
218+
continue;
219+
}
220+
else if ((this->ks_solver == "lapack") && (this->basis_type == "lcao_in_pw"))
214221
{
215-
ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
222+
if (ik == 0)
223+
{
224+
ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
225+
}
226+
continue;
216227
}
217-
continue;
228+
// else the case is davidson
218229
}
219-
// else the case is davidson
220-
}
221-
else
222-
{
223-
if (this->ks_solver == "cg")
230+
else
224231
{
225-
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt, *(psig_), *(kspw_psi), etatom.data());
226-
continue;
232+
if (this->ks_solver == "cg")
233+
{
234+
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt, *(psig_), *(kspw_psi), etatom.data());
235+
continue;
236+
}
237+
// else the case is davidson
227238
}
228-
// else the case is davidson
229-
}
230239

231-
// for the Davidson method, we just copy the wavefunction (partially)
232-
for (int iband = 0; iband < kspw_psi->get_nbands(); iband++)
233-
{
234-
for (int ibasis = 0; ibasis < kspw_psi->get_nbasis(); ibasis++)
240+
// for the Davidson method, we just copy the wavefunction (partially)
241+
for (int iband = 0; iband < kspw_psi->get_nbands(); iband++)
235242
{
236-
(*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
243+
for (int ibasis = 0; ibasis < kspw_psi->get_nbasis(); ibasis++)
244+
{
245+
(*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
246+
}
237247
}
248+
} // end k-point loop
249+
250+
if (this->basis_type != "lcao_in_pw")
251+
{
252+
this->psi_init->deallocate_psig();
238253
}
239-
} // end k-point loop
254+
}
255+
else
256+
{
257+
// if (PARAM.inp.basis_type == "pw")
258+
// {
259+
// for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
260+
// {
261+
// //! Update Hamiltonian from other kpoint to the given one
262+
// p_hamilt->updateHk(ik);
263+
264+
// //! Fix the wavefunction to initialize at given kpoint
265+
// kspw_psi->fix_k(ik);
240266

241-
if (this->basis_type != "lcao_in_pw") { this->psi_init->deallocate_psig(); }
267+
// /// for psi init guess!!!!
268+
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *kspw_psi, this->pw_wfc, this->p_wf, p_hamilt);
269+
// }
270+
// }
271+
}
242272

243273
ModuleBase::timer::tick("WFInit", "initialize_psi");
244274
}

source/module_hamilt_pw/hamilt_pwdft/wfinit.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class WFInit
8484
std::string basis_type = "none";
8585
// pw basis
8686
ModulePW::PW_Basis_K* pw_wfc = nullptr;
87+
88+
Device* ctx = {};
8789
};
8890

8991
} // namespace psi

0 commit comments

Comments
 (0)