Skip to content

Commit 8094ec8

Browse files
committed
Refactor: add orbital parallel algorithm for sdft
1 parent 804de46 commit 8094ec8

22 files changed

+203
-181
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ option(USE_CUDA "Enable support to CUDA." OFF)
1717
option(USE_ROCM "Enable support to ROCm." OFF)
1818
option(USE_OPENMP " Enable OpenMP in abacus." ON)
1919
option(ENABLE_ASAN "Enable AddressSanitizer" OFF)
20-
option(BUILD_TESTING "Build ABACUS unit tests" ON)
20+
option(BUILD_TESTING "Build ABACUS unit tests" OFF)
2121
option(GENERATE_TEST_REPORTS "Enable test report generation" OFF)
2222

2323
set(ABACUS_BIN_NAME abacus)

source/Makefile.Objects

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ parallel_reduce.o\
223223
parallel_pw.o\
224224
ft.o\
225225
parallel_grid.o\
226-
parallel_stochi.o\
227226

228227
OBJS_ESOLVER=esolver.o\
229228
esolver_ks.o\

source/driver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ void Driver::atomic_world(void)
9696
ModuleESolver::ESolver *p_esolver;
9797
if(GlobalV::BASIS_TYPE=="pw" || GlobalV::BASIS_TYPE=="lcao_in_pw")
9898
{
99-
use_ensol = "ksdft_pw";
99+
if(GlobalV::CALCULATION.substr(0,3) == "sto") use_ensol = "sdft_pw";
100+
else use_ensol = "ksdft_pw";
100101
//We set it temporarily
101102
//Finally, we have ksdft_pw, ksdft_lcao, sdft_pw, ofdft, lj, eam, etc.
102103
ModuleESolver::init_esolver(p_esolver, use_ensol);

source/input.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ void Input::Default(void)
140140
pw_seed = 1;
141141
nche_sto = 0;
142142
seed_sto = 0;
143-
stotype = "pw";
144143
kpar = 1;
145144
berry_phase = false;
146145
gdir = 3;
@@ -552,9 +551,9 @@ bool Input::Read(const std::string &fn)
552551
{
553552
read_value(ifs, emin_sto);
554553
}
555-
else if (strcmp("stotype", word) == 0)
554+
else if (strcmp("nstogroup", word) == 0)
556555
{
557-
read_value(ifs, stotype);
556+
read_value(ifs, nstogroup);
558557
}
559558
else if (strcmp("kpar", word) == 0) // number of pools
560559
{
@@ -1798,6 +1797,7 @@ void Input::Default_2(void) // jiyy add 2019-08-04
17981797
vdw_radius = "95";
17991798
}
18001799
}
1800+
if(calculation.substr(0,3) != "sto") nstogroup = 1;
18011801
}
18021802
#ifdef __MPI
18031803
void Input::Bcast()
@@ -1828,7 +1828,7 @@ void Input::Bcast()
18281828
Parallel_Common::bcast_int(pw_seed);
18291829
Parallel_Common::bcast_double(emax_sto);
18301830
Parallel_Common::bcast_double(emin_sto);
1831-
Parallel_Common::bcast_string(stotype);
1831+
Parallel_Common::bcast_int(nstogroup);
18321832
Parallel_Common::bcast_int(kpar);
18331833
Parallel_Common::bcast_bool(berry_phase);
18341834
Parallel_Common::bcast_int(gdir);
@@ -2174,7 +2174,7 @@ void Input::Check(void)
21742174
*/
21752175
this->relax_nmax = 1;
21762176
}
2177-
else if (calculation == "scf-sto") // qianrui 2021-2-20
2177+
else if (calculation == "sto-scf") // qianrui 2021-2-20
21782178
{
21792179
if (mem_saver == 1)
21802180
{

source/input.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class Input
6868
int seed_sto; // random seed for sDFT
6969
double emax_sto; // Emax & Emin to normalize H
7070
double emin_sto;
71-
std::string stotype;
71+
int nstogroup;
7272

7373
//==========================================================
7474
// electrons / spin

source/input_conv.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ void Input_Conv::Convert(void)
6161
GlobalV::KPAR = temp_nproc;
6262
#else
6363
GlobalV::KPAR = INPUT.kpar;
64+
GlobalV::NSTOGROUP = INPUT.nstogroup;
6465
#endif
6566
GlobalV::CALCULATION = INPUT.calculation;
6667

source/module_elecstate/elecstate_pw.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,19 @@ void ElecStatePW::parallelK()
4141
{
4242
#ifdef __MPI
4343
charge->rho_mpi();
44-
if (GlobalV::CALCULATION != "scf-sto" && GlobalV::CALCULATION != "relax-sto"
45-
&& GlobalV::CALCULATION != "md-sto") // qinarui add it temporarily.
46-
{
47-
//==================================
48-
// Reduce all the Energy in each cpu
49-
//==================================
50-
this->eband /= GlobalV::NPROC_IN_POOL;
51-
Parallel_Reduce::reduce_double_all(this->eband);
52-
}
44+
if(GlobalV::CALCULATION.substr(0,3) == "sto") //qinarui add it 2021-7-21
45+
{
46+
GlobalC::en.eband /= GlobalV::NPROC_IN_POOL;
47+
MPI_Allreduce(MPI_IN_PLACE, &GlobalC::en.eband, 1, MPI_DOUBLE, MPI_SUM , STO_WORLD);
48+
}
49+
else
50+
{
51+
//==================================
52+
// Reduce all the Energy in each cpu
53+
//==================================
54+
GlobalC::en.eband /= GlobalV::NPROC_IN_POOL;
55+
Parallel_Reduce::reduce_double_all( GlobalC::en.eband );
56+
}
5357
#endif
5458
return;
5559
}

source/module_esolver/esolver_ks.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,17 @@ void ESolver_KS:: Run(const int istep, UnitCell_pseudo& cell)
5353
clock_t iterstart,iterend;
5454
iterstart = std::clock();
5555
set_ethr(istep,iter);
56-
eachiterinit(iter);
56+
eachiterinit(iter);
5757

5858
this->hamilt2density(istep, iter, this->diag_ethr);
5959

60+
//<Temporary> This may be changed when a more clever parallel algorithm is put forward.
61+
//When the parallel algorithm for bands is adopted, the density will only be treated in the first group.
62+
//(Different ranks should have obtained the same result, but small differences always exist in practice.)
63+
//Maybe in the future, density and wavefunctions should use different parallel algorithms, in which
64+
//they do not occupy all processors; for example, wavefunctions use 20 processors while density uses 10.
65+
if(GlobalV::MY_STOGROUP == 0)
66+
{
6067
// double drho = this->estate.caldr2();
6168
// EState should be used after it is constructed.
6269
drho = GlobalC::CHR.get_drho();
@@ -88,6 +95,13 @@ void ESolver_KS:: Run(const int istep, UnitCell_pseudo& cell)
8895
//conv_elec = this->estate.mix_rho();
8996
GlobalC::CHR.mix_rho(iter);
9097
}
98+
99+
}
100+
#ifdef __MPI
101+
MPI_Bcast(&drho, 1, MPI_DOUBLE , 0, PARAPW_WORLD);
102+
MPI_Bcast(&conv_elec, 1, MPI_DOUBLE , 0, PARAPW_WORLD);
103+
MPI_Bcast(GlobalC::CHR.rho[0], GlobalC::pw.nrxx, MPI_DOUBLE, 0, PARAPW_WORLD);
104+
#endif
91105

92106
// Hamilt should be used after it is constructed.
93107
// this->phamilt->update(conv_elec);

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ void ESolver_KS_PW::Init(Input &inp, UnitCell_pseudo &ucell)
151151
//================================
152152
// Initial start wave functions
153153
//================================
154-
if (GlobalV::NBANDS != 0 || (GlobalV::CALCULATION != "scf-sto" && GlobalV::CALCULATION != "relax-sto" && GlobalV::CALCULATION != "md-sto")) //qianrui add
154+
if (GlobalV::NBANDS != 0 || GlobalV::CALCULATION.substr(0,3) != "sto")
155+
// qianrui added this temporarily. In the future, wfcinit() should be compatible with cases where NBANDS=0.
155156
{
156157
GlobalC::wf.wfcinit();
157158
}

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,12 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
205205
{
206206
if(ModuleSymmetry::Symmetry::symm_flag) MPI_Barrier(MPI_COMM_WORLD);
207207
}
208+
209+
if(GlobalV::MY_STOGROUP == 0)
210+
{
211+
GlobalC::en.deband = GlobalC::en.delta_e();
212+
}
213+
208214
}
209215

210216
void ESolver_SDFT_PW:: c_bands_k(const int ik, double* h_diag, const int istep, const int iter)

0 commit comments

Comments
 (0)