11#include " hsolver_pw.h"
22
3- #include " module_parameter/parameter.h"
43#include " module_base/global_variable.h"
54#include " module_base/timer.h"
65#include " module_base/tool_quit.h"
76#include " module_elecstate/elecstate_pw.h"
87#include " module_hamilt_general/hamilt.h"
98#include " module_hamilt_pw/hamilt_pwdft/wavefunc.h"
10- #include " module_psi/psi.h"
11-
129#include " module_hsolver/diag_comm_info.h"
1310#include " module_hsolver/diago_bpcg.h"
1411#include " module_hsolver/diago_cg.h"
1512#include " module_hsolver/diago_dav_subspace.h"
1613#include " module_hsolver/diago_david.h"
1714#include " module_hsolver/diago_iter_assist.h"
15+ #include " module_parameter/parameter.h"
16+ #include " module_psi/psi.h"
1817
1918#include < algorithm>
2019#include < vector>
@@ -209,26 +208,30 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
209208#endif
210209
211210template <typename T, typename Device>
212- void HSolverPW<T, Device>::cal_ethr_band(const double & wk, const double * wg, const double & ethr, std::vector<double >& ethrs)
211+ void HSolverPW<T, Device>::cal_ethr_band(const double & wk,
212+ const double * wg,
213+ const double & ethr,
214+ std::vector<double >& ethrs)
213215{
214216 // threshold for classifying occupied and unoccupied bands
215217 const double occ_threshold = 1e-2 ;
216218 // diagonalization threshold limitation for unoccupied bands
217219 const double ethr_limit = 1e-5 ;
218- if (wk > 0.0 )
220+ if (wk > 0.0 )
219221 {
220222 // Note: the idea of threshold for unoccupied bands (1e-5) comes from QE
221- // In ABACUS, We applied a smoothing process to this truncation to avoid abrupt changes in energy errors between different bands.
223+ // In ABACUS, We applied a smoothing process to this truncation to avoid abrupt changes in energy errors between
224+ // different bands.
222225 const double ethr_unocc = std::max (ethr_limit, ethr);
223226 for (int i = 0 ; i < ethrs.size (); i++)
224227 {
225228 double band_weight = wg[i] / wk;
226229 if (band_weight > occ_threshold)
227230 {
228- ethrs[i] = ethr;
231+ ethrs[i] = ethr;
229232 }
230- else if (band_weight > ethr_limit)
231- {// similar energy difference for different bands when band_weight in range [1e-5, 1e-2]
233+ else if (band_weight > ethr_limit)
234+ { // similar energy difference for different bands when band_weight in range [1e-5, 1e-2]
232235 ethrs[i] = std::min (ethr_unocc, ethr / band_weight);
233236 }
234237 else
@@ -258,6 +261,27 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
258261 ModuleBase::TITLE (" HSolverPW" , " solve" );
259262 ModuleBase::timer::tick (" HSolverPW" , " solve" );
260263
264+ // ---------------------------------------------------------------------------------------------------------------
265+ // ---------------------------------for psi init guess!!!!--------------------------------------------------------
266+ // ---------------------------------------------------------------------------------------------------------------
267+ if (!PARAM.inp .psi_initializer && !this ->initialed_psi && this ->basis_type == " pw" )
268+ {
269+ for (int ik = 0 ; ik < this ->wfc_basis ->nks ; ++ik)
270+ {
271+ // / update H(k) for each k point
272+ pHamilt->updateHk (ik);
273+
274+ // / update psi pointer for each k point
275+ psi.fix_k (ik);
276+
277+ // / for psi init guess!!!!
278+ hamilt::diago_PAO_in_pw_k2 (this ->ctx , ik, psi, this ->wfc_basis , this ->pwf , pHamilt);
279+ }
280+ }
281+ // ---------------------------------------------------------------------------------------------------------------
282+ // ---------------------------------------------------------------------------------------------------------------
283+ // ---------------------------------------------------------------------------------------------------------------
284+
261285 this ->rank_in_pool = rank_in_pool_in;
262286 this ->nproc_in_pool = nproc_in_pool_in;
263287
@@ -283,19 +307,20 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
283307 this ->paw_func_in_kloop (ik);
284308#endif
285309
286- this ->updatePsiK (pHamilt, psi, ik);
310+ // / update psi pointer for each k point
311+ psi.fix_k (ik);
287312
288313 // template add precondition calculating here
289314 update_precondition (precondition, ik, this ->wfc_basis ->npwk [ik], Real (pes->pot ->get_vl_of_0 ()));
290-
315+
291316 // only dav_subspace method used smooth threshold for all bands now,
292317 // for other methods, this trick can be added in the future to accelerate calculation without accuracy loss.
293- if (this ->method == " dav_subspace" )
318+ if (this ->method == " dav_subspace" )
294319 {
295320 this ->cal_ethr_band (pes->klist ->wk [ik],
296- &pes->wg (ik, 0 ),
297- DiagoIterAssist<T, Device>::PW_DIAG_THR,
298- ethr_band);
321+ &pes->wg (ik, 0 ),
322+ DiagoIterAssist<T, Device>::PW_DIAG_THR,
323+ ethr_band);
299324 }
300325
301326#ifdef USE_PAW
@@ -309,8 +334,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
309334 {
310335 GlobalV::ofs_running << " Average iterative diagonalization steps for k-points " << ik
311336 << " is: " << DiagoIterAssist<T, Device>::avg_iter
312- << " ; where current threshold is: " << this ->diag_thr
313- << " . " << std::endl;
337+ << " ; where current threshold is: " << this ->diag_thr << " . " << std::endl;
314338 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
315339 }
316340 // / calculate the contribution of Psi for charge density rho
@@ -347,17 +371,6 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
347371 }
348372}
349373
350- template <typename T, typename Device>
351- void HSolverPW<T, Device>::updatePsiK(hamilt::Hamilt<T, Device>* pHamilt, psi::Psi<T, Device>& psi, const int ik)
352- {
353- psi.fix_k (ik);
354- if (!PARAM.inp .psi_initializer && !this ->initialed_psi && this ->basis_type == " pw" )
355- {
356- hamilt::diago_PAO_in_pw_k2 (this ->ctx , ik, psi, this ->wfc_basis , this ->pwf , pHamilt);
357- }
358- /* lcao_in_pw now is based on newly implemented psi initializer, so it does not appear here*/
359- }
360-
361374template <typename T, typename Device>
362375void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
363376 psi::Psi<T, Device>& psi,
@@ -484,16 +497,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
484497 {
485498 auto ngk_pointer = psi.get_ngk_pointer ();
486499 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
487- auto hpsi_func = [hm, ngk_pointer](T *psi_in,
488- T *hpsi_out,
489- const int ld_psi,
490- const int nvec) {
500+ auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
491501 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
492502
493503 // Convert "pointer data stucture" to a psi::Psi object
494504 auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer);
495505
496- psi::Range bands_range (true , 0 , 0 , nvec- 1 );
506+ psi::Range bands_range (true , 0 , 0 , nvec - 1 );
497507
498508 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
499509 hpsi_info info (&psi_iter_wrapper, bands_range, hpsi_out);
@@ -513,8 +523,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
513523 this ->need_subspace ,
514524 comm_info);
515525
516- DiagoIterAssist<T, Device>::avg_iter
517- += static_cast < double >( dav_subspace.diag (hpsi_func, psi.get_pointer (), psi.get_nbasis (), eigenvalue, this ->ethr_band .data (), scf));
526+ DiagoIterAssist<T, Device>::avg_iter += static_cast < double >(
527+ dav_subspace.diag (hpsi_func, psi.get_pointer (), psi.get_nbasis (), eigenvalue, this ->ethr_band .data (), scf));
518528 }
519529 else if (this ->method == " dav" )
520530 {
@@ -533,23 +543,20 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
533543 // dimensions of matrix to be solved
534544 const int dim = psi.get_current_nbas (); // / dimension of matrix
535545 const int nband = psi.get_nbands (); // / number of eigenpairs sought
536- const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
546+ const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
537547
538548 // Davidson matrix-blockvector functions
539549
540550 auto ngk_pointer = psi.get_ngk_pointer ();
541551 // / wrap hpsi into lambda function, Matrix \times blockvector
542552 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
543- auto hpsi_func = [hm, ngk_pointer](T *psi_in,
544- T *hpsi_out,
545- const int ld_psi,
546- const int nvec) {
553+ auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
547554 ModuleBase::timer::tick (" David" , " hpsi_func" );
548555
549556 // Convert pointer of psi_in to a psi::Psi object
550557 auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer);
551558
552- psi::Range bands_range (true , 0 , 0 , nvec- 1 );
559+ psi::Range bands_range (true , 0 , 0 , nvec - 1 );
553560
554561 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
555562 hpsi_info info (&psi_iter_wrapper, bands_range, hpsi_out);
@@ -561,23 +568,19 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
561568 // / wrap spsi into lambda function, Matrix \times blockvector
562569 // / spsi(X, SX, ld, nvec)
563570 // / ld is leading dimension of psi and spsi
564- auto spsi_func = [hm](const T* psi_in, T* spsi_out,
565- const int ld_psi, // Leading dimension of psi and spsi.
566- const int nvec // Number of vectors(bands)
567- ){
571+ auto spsi_func = [hm](const T* psi_in,
572+ T* spsi_out,
573+ const int ld_psi, // Leading dimension of psi and spsi.
574+ const int nvec // Number of vectors(bands)
575+ ) {
568576 ModuleBase::timer::tick (" David" , " spsi_func" );
569577 // sPsi determines S=I or not by PARAM.globalv.use_uspp inside
570578 // sPsi(psi, spsi, nrow, npw, nbands)
571579 hm->sPsi (psi_in, spsi_out, ld_psi, ld_psi, nvec);
572580 ModuleBase::timer::tick (" David" , " spsi_func" );
573581 };
574582
575- DiagoDavid<T, Device> david (pre_condition.data (),
576- nband,
577- dim,
578- PARAM.inp .pw_diag_ndim ,
579- this ->use_paw ,
580- comm_info);
583+ DiagoDavid<T, Device> david (pre_condition.data (), nband, dim, PARAM.inp .pw_diag_ndim , this ->use_paw , comm_info);
581584 // do diag and add davidson iteration counts up to avg_iter
582585 DiagoIterAssist<T, Device>::avg_iter += static_cast <double >(david.diag (hpsi_func,
583586 spsi_func,
@@ -593,7 +596,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
593596}
594597
595598template <typename T, typename Device>
596- void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0)
599+ void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag,
600+ const int ik,
601+ const int npw,
602+ const Real vl_of_0)
597603{
598604 h_diag.assign (h_diag.size (), 1.0 );
599605 int precondition_type = 2 ;
@@ -646,8 +652,7 @@ void HSolverPW<T, Device>::output_iterInfo()
646652 {
647653 GlobalV::ofs_running << " Average iterative diagonalization steps: "
648654 << DiagoIterAssist<T, Device>::avg_iter / this ->wfc_basis ->nks
649- << " ; where current threshold is: " << this ->diag_thr << " . "
650- << std::endl;
655+ << " ; where current threshold is: " << this ->diag_thr << " . " << std::endl;
651656 // reset avg_iter
652657 DiagoIterAssist<T, Device>::avg_iter = 0.0 ;
653658 }
0 commit comments