Skip to content

Commit 10d4ef7

Browse files
committed
Refactor the relevant code for psi initializer
1 parent a2fdb95 commit 10d4ef7

File tree

3 files changed

+88
-61
lines changed

3 files changed

+88
-61
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
#include "hsolver_pw.h"
22

3-
#include "module_parameter/parameter.h"
43
#include "module_base/global_variable.h"
54
#include "module_base/timer.h"
65
#include "module_base/tool_quit.h"
76
#include "module_elecstate/elecstate_pw.h"
87
#include "module_hamilt_general/hamilt.h"
98
#include "module_hamilt_pw/hamilt_pwdft/wavefunc.h"
10-
#include "module_psi/psi.h"
11-
129
#include "module_hsolver/diag_comm_info.h"
1310
#include "module_hsolver/diago_bpcg.h"
1411
#include "module_hsolver/diago_cg.h"
1512
#include "module_hsolver/diago_dav_subspace.h"
1613
#include "module_hsolver/diago_david.h"
1714
#include "module_hsolver/diago_iter_assist.h"
15+
#include "module_parameter/parameter.h"
16+
#include "module_psi/psi.h"
1817

1918
#include <algorithm>
2019
#include <vector>
@@ -209,26 +208,30 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
209208
#endif
210209

211210
template <typename T, typename Device>
212-
void HSolverPW<T, Device>::cal_ethr_band(const double& wk, const double* wg, const double& ethr, std::vector<double>& ethrs)
211+
void HSolverPW<T, Device>::cal_ethr_band(const double& wk,
212+
const double* wg,
213+
const double& ethr,
214+
std::vector<double>& ethrs)
213215
{
214216
// threshold for classifying occupied and unoccupied bands
215217
const double occ_threshold = 1e-2;
216218
// diagonalization threshold limitation for unoccupied bands
217219
const double ethr_limit = 1e-5;
218-
if(wk > 0.0)
220+
if (wk > 0.0)
219221
{
220222
// Note: the idea of threshold for unoccupied bands (1e-5) comes from QE
221-
// In ABACUS, We applied a smoothing process to this truncation to avoid abrupt changes in energy errors between different bands.
223+
// In ABACUS, We applied a smoothing process to this truncation to avoid abrupt changes in energy errors between
224+
// different bands.
222225
const double ethr_unocc = std::max(ethr_limit, ethr);
223226
for (int i = 0; i < ethrs.size(); i++)
224227
{
225228
double band_weight = wg[i] / wk;
226229
if (band_weight > occ_threshold)
227230
{
228-
ethrs[i] = ethr;
231+
ethrs[i] = ethr;
229232
}
230-
else if(band_weight > ethr_limit)
231-
{// similar energy difference for different bands when band_weight in range [1e-5, 1e-2]
233+
else if (band_weight > ethr_limit)
234+
{ // similar energy difference for different bands when band_weight in range [1e-5, 1e-2]
232235
ethrs[i] = std::min(ethr_unocc, ethr / band_weight);
233236
}
234237
else
@@ -258,6 +261,27 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
258261
ModuleBase::TITLE("HSolverPW", "solve");
259262
ModuleBase::timer::tick("HSolverPW", "solve");
260263

264+
//---------------------------------------------------------------------------------------------------------------
265+
//---------------------------------for psi init guess!!!!--------------------------------------------------------
266+
//---------------------------------------------------------------------------------------------------------------
267+
if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
268+
{
269+
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
270+
{
271+
/// update H(k) for each k point
272+
pHamilt->updateHk(ik);
273+
274+
/// update psi pointer for each k point
275+
psi.fix_k(ik);
276+
277+
/// for psi init guess!!!!
278+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
279+
}
280+
}
281+
//---------------------------------------------------------------------------------------------------------------
282+
//---------------------------------------------------------------------------------------------------------------
283+
//---------------------------------------------------------------------------------------------------------------
284+
261285
this->rank_in_pool = rank_in_pool_in;
262286
this->nproc_in_pool = nproc_in_pool_in;
263287

@@ -283,19 +307,20 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
283307
this->paw_func_in_kloop(ik);
284308
#endif
285309

286-
this->updatePsiK(pHamilt, psi, ik);
310+
/// update psi pointer for each k point
311+
psi.fix_k(ik);
287312

288313
// template add precondition calculating here
289314
update_precondition(precondition, ik, this->wfc_basis->npwk[ik], Real(pes->pot->get_vl_of_0()));
290-
315+
291316
// only dav_subspace method used smooth threshold for all bands now,
292317
// for other methods, this trick can be added in the future to accelerate calculation without accuracy loss.
293-
if (this->method == "dav_subspace")
318+
if (this->method == "dav_subspace")
294319
{
295320
this->cal_ethr_band(pes->klist->wk[ik],
296-
&pes->wg(ik, 0),
297-
DiagoIterAssist<T, Device>::PW_DIAG_THR,
298-
ethr_band);
321+
&pes->wg(ik, 0),
322+
DiagoIterAssist<T, Device>::PW_DIAG_THR,
323+
ethr_band);
299324
}
300325

301326
#ifdef USE_PAW
@@ -309,8 +334,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
309334
{
310335
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
311336
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
312-
<< " ; where current threshold is: " << this->diag_thr
313-
<< " . " << std::endl;
337+
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
314338
DiagoIterAssist<T, Device>::avg_iter = 0.0;
315339
}
316340
/// calculate the contribution of Psi for charge density rho
@@ -347,17 +371,6 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
347371
}
348372
}
349373

350-
template <typename T, typename Device>
351-
void HSolverPW<T, Device>::updatePsiK(hamilt::Hamilt<T, Device>* pHamilt, psi::Psi<T, Device>& psi, const int ik)
352-
{
353-
psi.fix_k(ik);
354-
if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
355-
{
356-
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
357-
}
358-
/* lcao_in_pw now is based on newly implemented psi initializer, so it does not appear here*/
359-
}
360-
361374
template <typename T, typename Device>
362375
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
363376
psi::Psi<T, Device>& psi,
@@ -484,16 +497,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
484497
{
485498
auto ngk_pointer = psi.get_ngk_pointer();
486499
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
487-
auto hpsi_func = [hm, ngk_pointer](T *psi_in,
488-
T *hpsi_out,
489-
const int ld_psi,
490-
const int nvec) {
500+
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
491501
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
492502

493503
// Convert "pointer data stucture" to a psi::Psi object
494504
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
495505

496-
psi::Range bands_range(true, 0, 0, nvec-1);
506+
psi::Range bands_range(true, 0, 0, nvec - 1);
497507

498508
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
499509
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
@@ -513,8 +523,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
513523
this->need_subspace,
514524
comm_info);
515525

516-
DiagoIterAssist<T, Device>::avg_iter
517-
+= static_cast<double>(dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band.data(), scf));
526+
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
527+
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band.data(), scf));
518528
}
519529
else if (this->method == "dav")
520530
{
@@ -533,23 +543,20 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
533543
// dimensions of matrix to be solved
534544
const int dim = psi.get_current_nbas(); /// dimension of matrix
535545
const int nband = psi.get_nbands(); /// number of eigenpairs sought
536-
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
546+
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
537547

538548
// Davidson matrix-blockvector functions
539549

540550
auto ngk_pointer = psi.get_ngk_pointer();
541551
/// wrap hpsi into lambda function, Matrix \times blockvector
542552
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
543-
auto hpsi_func = [hm, ngk_pointer](T *psi_in,
544-
T *hpsi_out,
545-
const int ld_psi,
546-
const int nvec) {
553+
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
547554
ModuleBase::timer::tick("David", "hpsi_func");
548555

549556
// Convert pointer of psi_in to a psi::Psi object
550557
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
551558

552-
psi::Range bands_range(true, 0, 0, nvec-1);
559+
psi::Range bands_range(true, 0, 0, nvec - 1);
553560

554561
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
555562
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
@@ -561,23 +568,19 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
561568
/// wrap spsi into lambda function, Matrix \times blockvector
562569
/// spsi(X, SX, ld, nvec)
563570
/// ld is leading dimension of psi and spsi
564-
auto spsi_func = [hm](const T* psi_in, T* spsi_out,
565-
const int ld_psi, // Leading dimension of psi and spsi.
566-
const int nvec // Number of vectors(bands)
567-
){
571+
auto spsi_func = [hm](const T* psi_in,
572+
T* spsi_out,
573+
const int ld_psi, // Leading dimension of psi and spsi.
574+
const int nvec // Number of vectors(bands)
575+
) {
568576
ModuleBase::timer::tick("David", "spsi_func");
569577
// sPsi determines S=I or not by PARAM.globalv.use_uspp inside
570578
// sPsi(psi, spsi, nrow, npw, nbands)
571579
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
572580
ModuleBase::timer::tick("David", "spsi_func");
573581
};
574582

575-
DiagoDavid<T, Device> david(pre_condition.data(),
576-
nband,
577-
dim,
578-
PARAM.inp.pw_diag_ndim,
579-
this->use_paw,
580-
comm_info);
583+
DiagoDavid<T, Device> david(pre_condition.data(), nband, dim, PARAM.inp.pw_diag_ndim, this->use_paw, comm_info);
581584
// do diag and add davidson iteration counts up to avg_iter
582585
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(david.diag(hpsi_func,
583586
spsi_func,
@@ -593,7 +596,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
593596
}
594597

595598
template <typename T, typename Device>
596-
void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0)
599+
void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag,
600+
const int ik,
601+
const int npw,
602+
const Real vl_of_0)
597603
{
598604
h_diag.assign(h_diag.size(), 1.0);
599605
int precondition_type = 2;
@@ -646,8 +652,7 @@ void HSolverPW<T, Device>::output_iterInfo()
646652
{
647653
GlobalV::ofs_running << "Average iterative diagonalization steps: "
648654
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
649-
<< " ; where current threshold is: " << this->diag_thr << " . "
650-
<< std::endl;
655+
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
651656
// reset avg_iter
652657
DiagoIterAssist<T, Device>::avg_iter = 0.0;
653658
}

source/module_hsolver/hsolver_pw.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@ class HSolverPW
5959
std::vector<Real>& pre_condition,
6060
Real* eigenvalue);
6161

62-
// psi initializer && change k point in psi
63-
void updatePsiK(hamilt::Hamilt<T, Device>* pHamilt, psi::Psi<T, Device>& psi, const int ik);
64-
6562
// calculate the precondition array for diagonalization in PW base
6663
void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);
6764

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include "hsolver_pw_sdft.h"
22

33
#include "module_base/global_function.h"
4+
#include "module_base/parallel_device.h"
45
#include "module_base/timer.h"
56
#include "module_base/tool_title.h"
6-
#include "module_base/parallel_device.h"
77
#include "module_elecstate/module_charge/symmetry_rho.h"
88

99
#include <algorithm>
@@ -28,6 +28,30 @@ void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
2828
const int nbands = psi.get_nbands();
2929
const int nks = psi.get_nk();
3030

31+
//---------------------------------------------------------------------------------------------------------------
32+
//---------------------------------for psi init guess!!!!--------------------------------------------------------
33+
//---------------------------------------------------------------------------------------------------------------
34+
if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
35+
{
36+
for (int ik = 0; ik < nks; ++ik)
37+
{
38+
/// update H(k) for each k point
39+
pHamilt->updateHk(ik);
40+
41+
if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
42+
{
43+
/// update psi pointer for each k point
44+
psi.fix_k(ik);
45+
46+
/// for psi init guess!!!!
47+
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
48+
}
49+
}
50+
}
51+
//---------------------------------------------------------------------------------------------------------------
52+
//---------------------------------------------------------------------------------------------------------------
53+
//---------------------------------------------------------------------------------------------------------------
54+
3155
// prepare for the precondition of diagonalization
3256
std::vector<double> precondition(psi.get_nbasis(), 0.0);
3357

@@ -44,8 +68,9 @@ void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
4468
pHamilt->updateHk(ik);
4569
if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
4670
{
47-
this->updatePsiK(pHamilt, psi, ik);
48-
// template add precondition calculating here
71+
/// update psi pointer for each k point
72+
psi.fix_k(ik);
73+
/// template add precondition calculating here
4974
this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0());
5075
/// solve eigenvector and eigenvalue for H(k)
5176
double* p_eigenvalues = &(pes->ekb(ik, 0));
@@ -105,7 +130,7 @@ void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
105130
{
106131
for (int is = 0; is < this->nspin; is++)
107132
{
108-
setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx);
133+
setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx);
109134
}
110135
}
111136
// calculate stochastic rho

0 commit comments

Comments
 (0)