Skip to content

Commit 8262eb6

Browse files
committed
Refactor: using HSolverPW, Psi, ElecStatePW, HamiltPW in ESolver_KS_P, evc should be deleted later
1 parent ffa4433 commit 8262eb6

24 files changed

+236
-56
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,10 @@ target_link_libraries(${ABACUS_BIN_NAME}
276276
ri
277277
driver
278278
xc
279+
hsolver
280+
elecstate
281+
hamilt
282+
psi
279283
esolver
280284
-lm
281285
)

source/module_elecstate/elecstate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace elecstate
1111
class ElecState
1212
{
1313
public:
14+
ElecState(){};
1415
virtual void init(Charge *chg_in, // pointer for class Charge
1516
const K_Vectors *klist_in,
1617
int nk_in, // number of k points
@@ -68,6 +69,8 @@ class ElecState
6869
// occupation weight for each k-point and band
6970
ModuleBase::matrix wg;
7071

72+
std::string classname = "none";
73+
7174
protected:
7275
// calculate ebands for all k points and all occupied bands
7376
void calEBand();

source/module_elecstate/elecstate_lcao.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ElecStateLCAO : public ElecState
2424
this->loc = loc_in;
2525
this->uhm = uhm_in;
2626
this->lowf = lowf_in;
27+
this->classname = "ElecStateLCAO";
2728
}
2829
// void init(Charge* chg_in):charge(chg_in){} override;
2930

source/module_elecstate/elecstate_pw.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,18 @@ namespace elecstate
1010
void ElecStatePW::psiToRho(const psi::Psi<std::complex<double>>& psi)
1111
{
1212
this->calculate_weights();
13+
1314
this->calEBand();
15+
16+
for(int is=0; is<GlobalV::NSPIN; is++)
17+
{
18+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
19+
if (XC_Functional::get_func_type() == 3)
20+
{
21+
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
22+
}
23+
}
24+
1425
for (int ik = 0; ik < psi.get_nk(); ++ik)
1526
{
1627
psi.fix_k(ik);
@@ -126,6 +137,7 @@ void ElecStatePW::rhoBandK(const psi::Psi<std::complex<double>>& psi)
126137
}
127138
}
128139
else
140+
{
129141
for (int ibnd = 0; ibnd < nbands; ibnd++)
130142
{
131143
///
@@ -171,6 +183,7 @@ void ElecStatePW::rhoBandK(const psi::Psi<std::complex<double>>& psi)
171183
}
172184
}
173185
}
186+
}
174187

175188
return;
176189
}

source/module_elecstate/elecstate_pw.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class ElecStatePW : public ElecState
1313
ElecStatePW(const PW_Basis* basis_in, Charge* chg_in, int nbands_in) : basis(basis_in)
1414
{
1515
init(chg_in, basis_in->Klist, basis_in->Klist->nks, nbands_in);
16+
this->classname = "ElecStatePW";
1617
}
1718
// void init(Charge* chg_in):charge(chg_in){} override;
1819

source/module_esolver/esolver_ks.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#include "./esolver_fp.h"
44
#include "string.h"
55
#include "fstream"
6+
#include "module_hsolver/hsolver.h"
7+
#include "module_hamilt/hamilt.h"
8+
#include "module_elecstate/elecstate.h"
69
// #include "estates.h"
710
// #include "h2e.h"
811
namespace ModuleESolver
@@ -56,6 +59,9 @@ class ESolver_KS: public ESolver_FP
5659
void reset_diagethr(std::ofstream &ofs_running, const double hsover_error);
5760

5861

62+
hsolver::HSolver* phsol = nullptr;
63+
elecstate::ElecState* pelec = nullptr;
64+
hamilt::Hamilt* phami = nullptr;
5965

6066
protected:
6167
std::string basisname; //PW or LCAO

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
//-----stress------------------
2727
#include "../src_pw/stress_pw.h"
2828
//---------------------------------------------------
29+
#include "module_hsolver/hsolver_pw.h"
30+
#include "module_elecstate/elecstate_pw.h"
31+
#include "module_hamilt/hamilt_pw.h"
32+
#include "module_hsolver/diago_iter_assist.h"
2933

3034
namespace ModuleESolver
3135
{
@@ -190,6 +194,41 @@ void ESolver_KS_PW:: beforescf()
190194
{
191195
srho.begin(is, GlobalC::CHR,GlobalC::pw, GlobalC::Pgrid, GlobalC::symm);
192196
}
197+
//init Psi, HSolver, ElecState, Hamilt
198+
hsolver::DiagoIterAssist::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;
199+
hsolver::DiagoIterAssist::PW_DIAG_THR = GlobalV::PW_DIAG_THR;
200+
const PW_Basis* pbas = &(GlobalC::pw);
201+
if(this->phsol == nullptr)
202+
{
203+
this->phsol = new hsolver::HSolverPW(pbas);
204+
}
205+
else if(this->phsol->classname != "HSolverPW")
206+
{
207+
delete[] this->phsol;
208+
this->phsol = new hsolver::HSolverPW(pbas);
209+
}
210+
this->phsol->method = GlobalV::KS_SOLVER;
211+
if(this->pelec == nullptr)
212+
{
213+
this->pelec = new elecstate::ElecStatePW( pbas, (Charge*)(&(GlobalC::CHR)), GlobalV::NBANDS);
214+
}
215+
else if(this->pelec->classname != "ElecStatePW")
216+
{
217+
delete[] this->pelec;
218+
this->pelec = new elecstate::ElecStatePW( pbas, (Charge*)(&(GlobalC::CHR)), GlobalV::NBANDS);
219+
}
220+
Hamilt_PW* hpw = &(GlobalC::hm.hpw);
221+
if(this->phami == nullptr)
222+
{
223+
this->phami = new hamilt::HamiltPW(hpw);
224+
}
225+
else if(this->phami->classname != "HamiltPW")
226+
{
227+
delete[] this->phami;
228+
this->phami = new hamilt::HamiltPW(hpw);
229+
}
230+
//initial psi
231+
GlobalC::wf.evc_transform_psi();
193232
}
194233

195234
void ESolver_KS_PW:: eachiterinit(const int iter)
@@ -217,22 +256,37 @@ void ESolver_KS_PW:: eachiterinit(const int iter)
217256
//Temporary, it should be replaced by hsolver later.
218257
void ESolver_KS_PW:: hamilt2density(const int istep, const int iter, const double ethr)
219258
{
220-
GlobalV::PW_DIAG_THR = ethr;
221-
this->c_bands(istep,iter);
222-
223-
GlobalC::en.eband = 0.0;
224-
GlobalC::en.demet = 0.0;
225-
GlobalC::en.ef = 0.0;
226-
GlobalC::en.ef_up = 0.0;
227-
GlobalC::en.ef_dw = 0.0;
259+
if(this->phsol != nullptr)
260+
{
261+
// reset energy
262+
this->pelec->eband = 0.0;
263+
this->pelec->demet = 0.0;
264+
this->pelec->ef = 0.0;
265+
GlobalC::en.ef_up = 0.0;
266+
GlobalC::en.ef_dw = 0.0;
267+
// choose if psi should be diag in subspace
268+
// be careful that istep start from 0 and iter start from 1
269+
if((istep==0||istep==1)&&iter==1)
270+
{
271+
hsolver::DiagoIterAssist::need_subspace = false;
272+
}
273+
else
274+
{
275+
hsolver::DiagoIterAssist::need_subspace = true;
276+
}
228277

229-
// calculate weights of each band.
230-
Occupy::calculate_weights();
278+
hsolver::DiagoIterAssist::PW_DIAG_THR = ethr;
279+
this->phsol->solve(this->phami, GlobalC::wf.psi[0], this->pelec);
231280

232-
// calculate new charge density according to
233-
// new wave functions.
234-
// calculate the new eband here.
235-
GlobalC::CHR.sum_band();
281+
// transform energy for print
282+
GlobalC::en.eband = this->pelec->eband;
283+
GlobalC::en.demet = this->pelec->demet;
284+
GlobalC::en.ef = this->pelec->ef;
285+
}
286+
else
287+
{
288+
ModuleBase::WARNING_QUIT("ESolver_KS_PW", "HSolver has not been initialed!");
289+
}
236290

237291
// add exx
238292
#ifdef __LCAO
@@ -350,6 +404,17 @@ void ESolver_KS_PW:: eachiterfinish(const int iter, const bool conv_elec)
350404

351405
void ESolver_KS_PW::afterscf(const bool conv_elec)
352406
{
407+
//temporary transform psi to evc
408+
// psi back to evc
409+
GlobalC::wf.psi_transform_evc();
410+
for(int ik=0; ik<this->pelec->ekb.nr; ++ik)
411+
{
412+
for(int ib=0; ib<this->pelec->ekb.nc; ++ib)
413+
{
414+
GlobalC::wf.ekb[ik][ib] = this->pelec->ekb(ik, ib);
415+
GlobalC::wf.wg(ik, ib) = this->pelec->wg(ik, ib);
416+
}
417+
}
353418
#ifdef __LCAO
354419
if(GlobalC::chi0_hilbert.epsilon) // pengfei 2016-11-23
355420
{

source/module_hamilt/hamilt.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class Hamilt
2525
virtual void matrix(MatrixBlock<std::complex<double>> &hk_in, MatrixBlock<std::complex<double>> &sk_in){return;}
2626
virtual void matrix(MatrixBlock<double> &hk_in, MatrixBlock<double> &sk_in){return;}
2727

28+
std::string classname = "none";
29+
2830
protected:
2931
// array, save operations from each operators
3032
// would be implemented later

source/module_hamilt/hamilt_lcao.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ template <typename T, typename T1> class HamiltLCAO : public Hamilt
4444
this->GK = GK_in;
4545
this->genH = genH_in;
4646
this->LM = LM_in;
47+
this->classname = "HamiltLCAO";
4748
}
4849
//~HamiltLCAO();
4950

source/module_hamilt/hamilt_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ void HamiltPW::ch_mock()
88
return;
99
}
1010

11-
void HamiltPW::hk_mock()
11+
void HamiltPW::hk_mock(const int ik)
1212
{
13+
this->hpw->init_k(ik);
1314
return;
1415
}
1516

0 commit comments

Comments
 (0)