Skip to content

Commit cbc8dc0

Browse files
committed
refactor: add SDFT for ESolver_SDFT_PW
move SDFT program in sdf-version ABACUS to ESolver_SDFT_PW Save on schedule and no tests for SDFT are set up.
1 parent 38b3375 commit cbc8dc0

29 files changed

+1190
-1574
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ option(USE_CUDA "Enable support to CUDA." OFF)
1717
option(USE_ROCM "Enable support to ROCm." OFF)
1818
option(USE_OPENMP " Enable OpenMP in abacus." ON)
1919
option(ENABLE_ASAN "Enable AddressSanitizer" OFF)
20-
option(BUILD_TESTING "Build ABACUS unit tests" OFF)
20+
option(BUILD_TESTING "Build ABACUS unit tests" ON)
2121
option(GENERATE_TEST_REPORTS "Enable test report generation" OFF)
2222

2323
set(ABACUS_BIN_NAME abacus)

source/Makefile.Objects

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@ soc.o\
5757
to_wannier90.o \
5858
unk_overlap_pw.o \
5959
berryphase.o \
60-
sto_elec.o\
61-
sto_wf.o\
6260
sto_iter.o\
61+
sto_wf.o\
6362
sto_hchi.o\
6463
sto_che.o\
6564

@@ -226,6 +225,7 @@ parallel_reduce.o\
226225
parallel_pw.o\
227226
ft.o\
228227
parallel_grid.o\
228+
parallel_stochi.o\
229229

230230
OBJS_ESOLVER=esolver.o\
231231
esolver_ks.o\

source/Makefile.vars

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ CPLUSPLUS_MPI = mpiicpc
44

55
LAPACK_DIR = $(MKLROOT)
66

7-
FFTW_DIR = /public/software/fftw_3.3.8
7+
FFTW_DIR = /home/qianrui/intelcompile/fftw_3.3.8
88

9-
ELPA_DIR = /public/software/elpa_21.05.002
9+
ELPA_DIR = /home/qianrui/intelcompile/elpa_21.05.002
1010
ELPA_INCLUDE_DIR = ${ELPA_DIR}/include/elpa-2021.05.002
1111

12-
CEREAL_DIR = /public/software/cereal
12+
CEREAL_DIR = /home/qianrui/headfile/cereal
1313

1414
# LIBXC_DIR = /public/software/libxc-5.0.0
1515

source/input.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ class Input
6464
// Stochastic DFT
6565
//==========================================================
6666
int nche_sto; // number of orders for Chebyshev expansion in stochastic DFT //qinarui 2021-2-5
67+
int nbands_sto; // number of stochastic bands //qianrui 2021-2-5
6768
int seed_sto; // random seed for sDFT
6869
double emax_sto; // Emax & Emin to normalize H
6970
double emin_sto;
7071
std::string stotype;
71-
int nbands_sto; // number of stochastic bands //qianrui 2021-2-5
7272

7373
//==========================================================
7474
// electrons / spin

source/module_base/global_variable.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,14 @@ int out_mul = 0; // qifeng add 2019/9/10
101101
//----------------------------------------------------------
102102
int NPROC = 1;
103103
int KPAR = 1;
104+
int NSTOGROUP = 1;
104105
int MY_RANK = 0;
105106
int MY_POOL = 0;
107+
int MY_STOGROUP = 0;
106108
int NPROC_IN_POOL = 1;
109+
int NPROC_IN_STOGROUP = 1;
107110
int RANK_IN_POOL = 0;
111+
int RANK_IN_STOGROUP = 0;
108112
int DRANK = -1; // mohan add 2012-01-13, must be -1, so we can recognize who didn't in DIAG_WORLD
109113
int DSIZE = KPAR;
110114
int DCOLOR = -1;

source/module_base/global_variable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,14 @@ extern int out_mul; // qifeng add 2019/9/10
116116
//========================================================================
117117
extern int NPROC;
118118
extern int KPAR;
119+
extern int NSTOGROUP;
119120
extern int MY_RANK;
120121
extern int MY_POOL;
122+
extern int MY_STOGROUP;
121123
extern int NPROC_IN_POOL;
124+
extern int NPROC_IN_STOGROUP;
122125
extern int RANK_IN_POOL;
126+
extern int RANK_IN_STOGROUP;
123127
extern int DRANK;
124128
extern int DSIZE;
125129
extern int DCOLOR;

source/module_cell/unitcell_pseudo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class UnitCell_pseudo : public UnitCell
3131
int lmax_ppwf;
3232
int lmaxmax; // liuyu 2021-07-04
3333
bool init_vel; // liuyu 2021-07-15
34-
//double nelec;
34+
double nelec;
3535

3636
public: // member functions
3737
UnitCell_pseudo();

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,10 @@ void ESolver_KS_PW:: eachiterinit(const int iter)
211211

212212
//(2) save change density as previous charge,
213213
// prepared fox mixing.
214-
GlobalC::CHR.save_rho_before_sum_band();
214+
if(GlobalV::MY_STOGROUP == 0)
215+
{
216+
GlobalC::CHR.save_rho_before_sum_band();
217+
}
215218
}
216219

217220
//Temporary, it should be replaced by hsolver later.

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 195 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
#include "./esolver_sdft_pw.h"
2+
#include "time.h"
3+
#include <fstream>
4+
#include <algorithm>
5+
#include "../module_base/timer.h"
6+
7+
//-------------------Temporary------------------
8+
#include "../module_base/global_variable.h"
9+
#include "../src_pw/global.h"
10+
#include "../src_pw/symmetry_rho.h"
11+
//----------------------------------------------
212

313
namespace ModuleESolver
414
{
@@ -9,18 +19,22 @@ ESolver_SDFT_PW::ESolver_SDFT_PW()
919
basisname = "PW";
1020
}
1121

22+
ESolver_SDFT_PW::~ESolver_SDFT_PW()
23+
{
24+
}
25+
1226
void ESolver_SDFT_PW::Init(Input &inp, UnitCell_pseudo &cell)
1327
{
14-
// ESolver_KS_PW::Init(inp, cell);
15-
// STO_WF.alloc();
16-
// stoiter.alloc( wf.npwx );
28+
ESolver_KS_PW::Init(inp, cell);
29+
stowf.init(GlobalC::kv.nks);
30+
if(INPUT.nbands_sto != 0) Init_Sto_Orbitals(this->stowf, INPUT.seed_sto);
31+
else Init_Com_Orbitals(this->stowf, GlobalC::kv);
32+
stoiter.init(GlobalC::wf.npwx, this->stowf.nchip);
1733
}
1834

1935
void ESolver_SDFT_PW::beforescf()
2036
{
2137
ESolver_KS_PW::beforescf();
22-
// STO_WF.init();
23-
// stoiter.init( wf.npwx );
2438
// if(NITER==0)
2539
// {
2640
// int iter = 1;
@@ -62,11 +76,11 @@ void ESolver_SDFT_PW::beforescf()
6276
// if(kv.nks > 1) hm.hpw.init_k(ik);
6377
// stoiter.stoche.ndmin = wf.npw;
6478
// complex<double> * out, *pchi;
65-
// out = STO_WF.shchi[ik].c;
79+
// out = stowf.shchi[ik].c;
6680
// if(NBANDS > 0)
67-
// pchi = STO_WF.chiortho[ik].c;
81+
// pchi = stowf.chiortho[ik].c;
6882
// else
69-
// pchi = STO_WF.chi0[ik].c;
83+
// pchi = stowf.chi0[ik].c;
7084

7185
// stoiter.stoche.calfinalvec(stoiter.stohchi.hchi_reciprocal, pchi, out, stoiter.nchip[ik]);
7286
// }
@@ -76,21 +90,190 @@ void ESolver_SDFT_PW::beforescf()
7690

7791
void ESolver_SDFT_PW::eachiterfinish(int iter, bool conv_elec)
7892
{
79-
93+
//print_eigenvalue(GlobalV::ofs_running);
94+
GlobalC::en.calculate_etot();
8095
}
8196
void ESolver_SDFT_PW::afterscf(bool conv_elec)
8297
{
83-
98+
for(int is=0; is<GlobalV::NSPIN; is++)
99+
{
100+
std::stringstream ssc;
101+
std::stringstream ss1;
102+
ssc << GlobalV::global_out_dir << "SPIN" << is + 1 << "_CHG";
103+
ss1 << GlobalV::global_out_dir << "SPIN" << is + 1 << "_CHG.cube";
104+
GlobalC::CHR.write_rho(GlobalC::CHR.rho_save[is], is, 0, ssc.str() );//mohan add 2007-10-17
105+
GlobalC::CHR.write_rho_cube(GlobalC::CHR.rho_save[is], is, ss1.str(), 3);
106+
}
107+
if(conv_elec)
108+
{
109+
//GlobalV::ofs_running << " convergence is achieved" << std::endl;
110+
//GlobalV::ofs_running << " !FINAL_ETOT_IS " << GlobalC::en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
111+
GlobalV::ofs_running << "\n charge density convergence is achieved" << std::endl;
112+
GlobalV::ofs_running << " final etot is " << GlobalC::en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
113+
}
114+
else
115+
{
116+
GlobalV::ofs_running << " convergence has NOT been achieved!" << std::endl;
117+
}
84118
}
85119

86120
void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
87121
{
88-
122+
double *h_diag = new double[GlobalC::wf.npwx * GlobalV::NPOL];
123+
GlobalV::ofs_running << " " <<setw(8) << "K-point" << setw(15) << "CG iter num" << setw(15) << "Time(Sec)"<< std::endl;
124+
GlobalV::ofs_running << setprecision(6) << setiosflags(ios::fixed) << setiosflags(ios::showpoint);
125+
for (int ik = 0;ik < GlobalC::kv.nks;ik++)
126+
{
127+
if(GlobalV::NBANDS > 0 && GlobalV::MY_STOGROUP == 0)
128+
{
129+
this->c_bands_k(ik,h_diag,istep+1,iter);
130+
}
131+
else
132+
{
133+
GlobalC::hm.hpw.init_k(ik);
134+
//In fact, hm.hpw.init_k has been done in wf.wfcinit();
135+
}
136+
137+
#ifdef __MPI
138+
if(GlobalV::NBANDS > 0)
139+
{
140+
MPI_Bcast(GlobalC::wf.evc[ik].c, GlobalC::wf.npwx*GlobalV::NBANDS*2, MPI_DOUBLE , 0, PARAPW_WORLD);
141+
MPI_Bcast(GlobalC::wf.ekb[ik], GlobalV::NBANDS, MPI_DOUBLE, 0, PARAPW_WORLD);
142+
}
143+
#endif
144+
stoiter.stoche.ndmin = GlobalC::wf.npw;
145+
stoiter.orthog(ik,this->stowf);
146+
stoiter.checkemm(ik,iter,this->stowf); //check and reset emax & emin
147+
}
148+
for (int ik = 0;ik < GlobalC::kv.nks;ik++)
149+
{
150+
//init k
151+
if(GlobalC::kv.nks > 1) GlobalC::hm.hpw.init_k(ik);
152+
stoiter.stoche.ndmin = GlobalC::wf.npw;
153+
154+
stoiter.sumpolyval_k(ik, this->stowf);
155+
}
156+
delete [] h_diag;
157+
GlobalC::en.eband = 0.0;
158+
GlobalC::en.demet = 0.0;
159+
GlobalC::en.ef = 0.0;
160+
GlobalC::en.ef_up = 0.0;
161+
GlobalC::en.ef_dw = 0.0;
162+
stoiter.itermu(iter);
163+
//(5) calculate new charge density
164+
// calculate KS rho.
165+
if(GlobalV::NBANDS > 0)
166+
{
167+
if(GlobalV::MY_STOGROUP == 0)
168+
{
169+
GlobalC::CHR.sum_band();
170+
}
171+
else
172+
{
173+
for(int is=0; is < GlobalV::NSPIN; is++)
174+
{
175+
ModuleBase::GlobalFunc::ZEROS(GlobalC::CHR.rho[is], GlobalC::pw.nrxx);
176+
}
177+
}
178+
MPI_Bcast(&GlobalC::en.eband,1, MPI_DOUBLE, 0,PARAPW_WORLD);
179+
}
180+
else
181+
{
182+
for(int is=0; is < GlobalV::NSPIN; is++)
183+
{
184+
ModuleBase::GlobalFunc::ZEROS(GlobalC::CHR.rho[is], GlobalC::pw.nrxx);
185+
}
186+
}
187+
// calculate stochastic rho
188+
stoiter.sum_stoband(this->stowf);
189+
190+
191+
//(6) calculate the delta_harris energy
192+
// according to new charge density.
193+
// mohan add 2009-01-23
194+
//en.calculate_harris(2);
195+
196+
if(GlobalV::MY_STOGROUP==0)
197+
{
198+
Symmetry_rho srho;
199+
for(int is=0; is < GlobalV::NSPIN; is++)
200+
{
201+
srho.begin(is, GlobalC::CHR,GlobalC::pw, GlobalC::Pgrid, GlobalC::symm);
202+
}
203+
}
204+
else
205+
{
206+
if(ModuleSymmetry::Symmetry::symm_flag) MPI_Barrier(MPI_COMM_WORLD);
207+
}
89208
}
90209

91-
void ESolver_SDFT_PW::cal_Energy(energy &en)
210+
void ESolver_SDFT_PW:: c_bands_k(const int ik, double* h_diag, const int istep, const int iter)
92211
{
212+
ModuleBase::timer::tick(this->classname,"c_bands_k");
213+
int precondition_type = 2;
214+
GlobalC::hm.hpw.init_k(ik);
215+
216+
//===========================================
217+
// Conjugate-Gradient diagonalization
218+
// h_diag is the precondition matrix
219+
// h_diag(1:npw) = MAX( 1.0, g2kin(1:npw) );
220+
//===========================================
221+
if (precondition_type==1)
222+
{
223+
for (int ig = 0;ig < GlobalC::wf.npw; ++ig)
224+
{
225+
h_diag[ig] = std::max(1.0, GlobalC::wf.g2kin[ig]);
226+
if(GlobalV::NPOL==2) h_diag[ig+GlobalC::wf.npwx] = h_diag[ig];
227+
}
228+
}
229+
else if (precondition_type==2)
230+
{
231+
for (int ig = 0;ig < GlobalC::wf.npw; ig++)
232+
{
233+
h_diag[ig] = 1 + GlobalC::wf.g2kin[ig] + sqrt( 1 + (GlobalC::wf.g2kin[ig] - 1) * (GlobalC::wf.g2kin[ig] - 1));
234+
if(GlobalV::NPOL==2) h_diag[ig+GlobalC::wf.npwx] = h_diag[ig];
235+
}
236+
}
237+
//h_diag can't be zero! //zhengdy-soc
238+
if(GlobalV::NPOL==2)
239+
{
240+
for(int ig = GlobalC::wf.npw;ig < GlobalC::wf.npwx; ig++)
241+
{
242+
h_diag[ig] = 1.0;
243+
h_diag[ig+ GlobalC::wf.npwx] = 1.0;
244+
}
245+
}
246+
clock_t start=clock();
247+
248+
//============================================================
249+
// diago the hamiltonian!!
250+
// In plane wave method, firstly using cinitcgg to diagnolize,
251+
// then using cg method.
252+
//
253+
// In localized orbital presented in plane wave case,
254+
// only using cinitcgg.
255+
//
256+
// In linear scaling method, using sparse matrix and
257+
// adjacent searching code and cg method to calculate the
258+
// eigenstates.
259+
//=============================================================
260+
double avg_iter_k = 0.0;
261+
GlobalC::hm.diagH_pw(istep, iter, ik, h_diag, avg_iter_k);
262+
263+
GlobalC::en.print_band(ik);
264+
clock_t finish=clock();
265+
const double duration = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
266+
GlobalV::ofs_running << " " << setw(8)
267+
<< ik+1 << setw(15)
268+
<< avg_iter_k << setw(15) << duration << endl;
269+
270+
ModuleBase::timer::tick(this->classname,"c_bands_k");
271+
}
93272

273+
274+
void ESolver_SDFT_PW::cal_Energy(energy &en)
275+
{
276+
94277
}
95278

96279
void ESolver_SDFT_PW::cal_Force(ModuleBase::matrix &force)

source/module_esolver/esolver_sdft_pw.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,36 @@
11
#include "./esolver_ks_pw.h"
2+
#include "../src_pw/sto_wf.h"
3+
#include "../src_pw/sto_iter.h"
4+
#include "../src_pw/sto_che.h"
5+
#include "../src_pw/sto_hchi.h"
6+
27
namespace ModuleESolver
38
{
49

510
class ESolver_SDFT_PW: public ESolver_KS_PW
611
{
712
public:
813
ESolver_SDFT_PW();
14+
~ESolver_SDFT_PW();
915
void Init(Input &inp, UnitCell_pseudo &cell) override;
1016
void cal_Energy(energy& en) override;
11-
void cal_Force(ModuleBase::matrix &force) override;
12-
void cal_Stress(ModuleBase::matrix &stress) override;
17+
void cal_Force(ModuleBase::matrix& force) override;
18+
void cal_Stress(ModuleBase::matrix& stress) override;
19+
public:
20+
Stochastic_WF stowf;
21+
Stochastic_Iter stoiter;
22+
// Stochastic_Chebychev stoche;
23+
// Stochastic_hchi stohchi;
1324

1425
protected:
1526
virtual void beforescf() override;
1627
// virtual void eachiterinit(int iter) override;
1728
virtual void hamilt2density(const int istep, const int iter, const double ethr) override;
1829
virtual void eachiterfinish(const int iter, const bool conv) override;
1930
virtual void afterscf(const bool) override;
31+
private:
32+
void c_bands_k(const int ik, double* h_diag, const int istep, const int iter);
33+
2034
};
2135

2236
}

0 commit comments

Comments
 (0)