Skip to content

Commit 3fec10a

Browse files
authored
Refacor: change the order of k-point parallel and band parallel. (#5692)
* fix stuck in out_chg * fix DCU low efficiency * refactor: change the order of kpar and bndpar ABACUS first divide nprocs to kpar group and then divide this groups to bndpar subgroups * fix bug * fix bug2
1 parent f997e86 commit 3fec10a

File tree

9 files changed

+139
-103
lines changed

9 files changed

+139
-103
lines changed

source/module_base/global_variable.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,9 @@ extern int GSIZE;
4747
extern int KPAR_LCAO;
4848

4949
//==========================================================
50-
// EXPLAIN : readin file dir, output file std::ofstream
51-
// GLOBAL VARIABLES :
52-
// NAME : global_in_card
53-
// NAME : stru_file
54-
// NAME : global_kpoint_card
55-
// NAME : global_wannier_card
56-
// NAME : global_pseudo_dir
57-
// NAME : global_pseudo_type // mohan add 2013-05-20 (xiaohui add 2013-06-23)
58-
// NAME : global_out_dir
5950
// NAME : ofs_running( contain information during runnnig)
6051
// NAME : ofs_warning( contain warning information, including error)
6152
//==========================================================
62-
// extern std::string global_pseudo_type; // mohan add 2013-05-20 (xiaohui add
63-
// 2013-06-23)
6453
extern std::ofstream ofs_running;
6554
extern std::ofstream ofs_warning;
6655
extern std::ofstream ofs_info;
Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,55 @@
11
#if defined __MPI
22

33
#include "mpi.h"
4+
#include "parallel_global.h"
45

56
MPI_Comm POOL_WORLD;
6-
MPI_Comm INTER_POOL = MPI_COMM_NULL; // communicator among different pools
7+
MPI_Comm INTER_POOL; // communicator among different pools
78
MPI_Comm STO_WORLD;
89
MPI_Comm PARAPW_WORLD;
910
MPI_Comm GRID_WORLD; // mohan add 2012-01-13
1011
MPI_Comm DIAG_WORLD; // mohan add 2012-01-13
1112

13+
MPICommGroup::MPICommGroup(MPI_Comm parent_comm)
14+
: parent_comm(parent_comm)
15+
{
16+
MPI_Comm_size(parent_comm, &this->gsize);
17+
MPI_Comm_rank(parent_comm, &this->grank);
18+
}
19+
20+
MPICommGroup::~MPICommGroup()
21+
{
22+
if (group_comm != MPI_COMM_NULL)
23+
{
24+
MPI_Comm_free(&group_comm);
25+
}
26+
if (inter_comm != MPI_COMM_NULL)
27+
{
28+
MPI_Comm_free(&inter_comm);
29+
}
30+
}
31+
32+
void MPICommGroup::divide_group_comm(const int& ngroup, const bool assert_even)
33+
{
34+
this->ngroups = ngroup;
35+
Parallel_Global::divide_mpi_groups(this->gsize,
36+
ngroup,
37+
this->grank,
38+
this->nprocs_in_group,
39+
this->my_group,
40+
this->rank_in_group,
41+
assert_even);
42+
43+
MPI_Comm_split(parent_comm, my_group, rank_in_group, &group_comm);
44+
if(this->gsize % ngroup == 0)
45+
{
46+
this->is_even = true;
47+
}
48+
49+
if (this->is_even)
50+
{
51+
MPI_Comm_split(parent_comm, my_inter, rank_in_inter, &inter_comm);
52+
}
53+
}
54+
1255
#endif

source/module_base/parallel_comm.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define PARALLEL_COMM_H
33

44
#ifdef __MPI
5-
65
#include "mpi.h"
76
extern MPI_Comm POOL_WORLD;
87
extern MPI_Comm INTER_POOL; // communicator among different pools
@@ -11,6 +10,33 @@ extern MPI_Comm PARAPW_WORLD;
1110
extern MPI_Comm GRID_WORLD; // mohan add 2012-01-13
1211
extern MPI_Comm DIAG_WORLD; // mohan add 2012-01-13
1312

13+
14+
class MPICommGroup
15+
{
16+
public:
17+
MPICommGroup(MPI_Comm parent_comm);
18+
~MPICommGroup();
19+
void divide_group_comm(const int& ngroup, const bool assert_even = true);
20+
public:
21+
bool is_even = false; ///< whether the group is even
22+
23+
MPI_Comm parent_comm = MPI_COMM_NULL; ///< parent communicator
24+
int gsize = 0; ///< size of parent communicator
25+
int grank = 0; ///< rank of parent communicator
26+
27+
MPI_Comm group_comm = MPI_COMM_NULL; ///< group communicator
28+
int ngroups = 0; ///< number of groups
29+
int nprocs_in_group = 0; ///< number of processes in the group
30+
int my_group = 0; ///< the group index
31+
int rank_in_group = 0; ///< the rank in the group
32+
33+
MPI_Comm inter_comm = MPI_COMM_NULL; ///< inter communicator
34+
bool has_inter_comm = false; ///< whether has inter communicator
35+
int& nprocs_in_inter = ngroups; ///< number of processes in the inter communicator
36+
int& my_inter = rank_in_group; ///< the rank in the inter communicator
37+
int& rank_in_inter = my_group; ///< the inter group index
38+
};
39+
1440
#endif
1541

1642
#endif // PARALLEL_COMM_H

source/module_base/parallel_global.cpp

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ void Parallel_Global::finalize_mpi()
251251

252252
void Parallel_Global::init_pools(const int& NPROC,
253253
const int& MY_RANK,
254-
const int& NSTOGROUP,
254+
const int& BNDPAR,
255255
const int& KPAR,
256256
int& NPROC_IN_STOGROUP,
257257
int& RANK_IN_STOGROUP,
@@ -266,7 +266,7 @@ void Parallel_Global::init_pools(const int& NPROC,
266266
//----------------------------------------------------------
267267
Parallel_Global::divide_pools(NPROC,
268268
MY_RANK,
269-
NSTOGROUP,
269+
BNDPAR,
270270
KPAR,
271271
NPROC_IN_STOGROUP,
272272
RANK_IN_STOGROUP,
@@ -314,7 +314,7 @@ void Parallel_Global::init_pools(const int& NPROC,
314314
#ifdef __MPI
315315
void Parallel_Global::divide_pools(const int& NPROC,
316316
const int& MY_RANK,
317-
const int& NSTOGROUP,
317+
const int& BNDPAR,
318318
const int& KPAR,
319319
int& NPROC_IN_STOGROUP,
320320
int& RANK_IN_STOGROUP,
@@ -323,30 +323,55 @@ void Parallel_Global::divide_pools(const int& NPROC,
323323
int& RANK_IN_POOL,
324324
int& MY_POOL)
325325
{
326-
// Divide the global communicator into stogroups.
327-
divide_mpi_groups(NPROC, NSTOGROUP, MY_RANK, NPROC_IN_STOGROUP, MY_STOGROUP, RANK_IN_STOGROUP, true);
328-
329-
// (2) per process in each pool
330-
divide_mpi_groups(NPROC_IN_STOGROUP, KPAR, RANK_IN_STOGROUP, NPROC_IN_POOL, MY_POOL, RANK_IN_POOL);
331-
332-
int key = 1;
333-
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, key, &STO_WORLD);
334-
335-
//========================================================
336-
// MPI_Comm_Split: Creates new communicators based on
337-
// colors(2nd parameter) and keys(3rd parameter)
338-
// Note: The color must be non-negative or MPI_UNDEFINED.
339-
//========================================================
340-
MPI_Comm_split(STO_WORLD, MY_POOL, key, &POOL_WORLD);
341-
342-
if (NPROC_IN_STOGROUP % KPAR == 0)
326+
// note: the order of k-point parallelization and band parallelization is important
327+
// The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL
328+
// and MY_STOGROUP will be the same as well.
329+
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
343330
{
344-
MPI_Comm_split(STO_WORLD, RANK_IN_POOL, key, &INTER_POOL);
331+
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
332+
<< BNDPAR * KPAR << ")." << std::endl;
333+
exit(1);
334+
}
335+
// k-point parallelization
336+
MPICommGroup kpar_group(MPI_COMM_WORLD);
337+
kpar_group.divide_group_comm(KPAR, false);
338+
339+
// band parallelization
340+
MPICommGroup bndpar_group(kpar_group.group_comm);
341+
bndpar_group.divide_group_comm(BNDPAR, true);
342+
343+
// Set parallel index.
344+
// In previous versions, the order of k-point parallelization and band parallelization is reversed.
345+
// So we need to keep some variables for compatibility.
346+
NPROC_IN_POOL = bndpar_group.nprocs_in_group;
347+
RANK_IN_POOL = bndpar_group.rank_in_group;
348+
MY_POOL = kpar_group.my_group;
349+
MPI_Comm_dup(bndpar_group.group_comm, &POOL_WORLD);
350+
if(kpar_group.inter_comm != MPI_COMM_NULL)
351+
{
352+
MPI_Comm_dup(kpar_group.inter_comm, &INTER_POOL);
353+
}
354+
else
355+
{
356+
INTER_POOL = MPI_COMM_NULL;
357+
}
358+
359+
if(BNDPAR > 1)
360+
{
361+
NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
362+
RANK_IN_STOGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group;
363+
MY_STOGROUP = bndpar_group.my_group;
364+
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_STOGROUP, &STO_WORLD);
365+
MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD);
366+
}
367+
else
368+
{
369+
NPROC_IN_STOGROUP = NPROC;
370+
RANK_IN_STOGROUP = MY_RANK;
371+
MY_STOGROUP = 0;
372+
MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD);
373+
MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD);
345374
}
346-
347-
int color = MY_RANK % NPROC_IN_STOGROUP;
348-
MPI_Comm_split(MPI_COMM_WORLD, color, key, &PARAPW_WORLD);
349-
350375
return;
351376
}
352377

@@ -380,31 +405,17 @@ void Parallel_Global::divide_mpi_groups(const int& procs,
380405
exit(1);
381406
}
382407

383-
int* nproc_group_ = new int[num_groups];
384-
385-
for (int i = 0; i < num_groups; i++)
408+
if(rank < extra_procs)
386409
{
387-
nproc_group_[i] = procs_in_group;
388-
if (i < extra_procs)
389-
{
390-
++nproc_group_[i];
391-
}
410+
procs_in_group++;
411+
my_group = rank / procs_in_group;
412+
rank_in_group = rank % procs_in_group;
392413
}
393-
394-
int np_now = 0;
395-
for (int i = 0; i < num_groups; i++)
414+
else
396415
{
397-
np_now += nproc_group_[i];
398-
if (rank < np_now)
399-
{
400-
my_group = i;
401-
procs_in_group = nproc_group_[i];
402-
rank_in_group = rank - (np_now - procs_in_group);
403-
break;
404-
}
416+
my_group = (rank - extra_procs) / procs_in_group;
417+
rank_in_group = (rank - extra_procs) % procs_in_group;
405418
}
406-
407-
delete[] nproc_group_;
408419
}
409420

410421
#endif

source/module_base/parallel_global.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void split_grid_world(const int diag_np, const int& nproc, const int& my_rank, i
4646
*/
4747
void init_pools(const int& NPROC,
4848
const int& MY_RANK,
49-
const int& NSTOGROUP,
49+
const int& BNDPAR,
5050
const int& KPAR,
5151
int& NPROC_IN_STOGROUP,
5252
int& RANK_IN_STOGROUP,
@@ -57,7 +57,7 @@ void init_pools(const int& NPROC,
5757

5858
void divide_pools(const int& NPROC,
5959
const int& MY_RANK,
60-
const int& NSTOGROUP,
60+
const int& BNDPAR,
6161
const int& KPAR,
6262
int& NPROC_IN_STOGROUP,
6363
int& RANK_IN_STOGROUP,

source/module_base/test_parallel/parallel_global_test.cpp

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mpi.h"
55

66
#include "gtest/gtest.h"
7+
#include "gmock/gmock.h"
78
#include <complex>
89
#include <cstring>
910
#include <string>
@@ -165,8 +166,8 @@ TEST_F(ParaGlobal, InitPools)
165166
mpi.kpar = 3;
166167
mpi.nstogroup = 3;
167168
my_rank = 5;
168-
169-
Parallel_Global::init_pools(nproc,
169+
testing::internal::CaptureStdout();
170+
EXPECT_EXIT(Parallel_Global::init_pools(nproc,
170171
my_rank,
171172
mpi.nstogroup,
172173
mpi.kpar,
@@ -175,45 +176,11 @@ TEST_F(ParaGlobal, InitPools)
175176
mpi.my_stogroup,
176177
mpi.nproc_in_pool,
177178
mpi.rank_in_pool,
178-
mpi.my_pool);
179-
EXPECT_EQ(mpi.nproc_in_stogroup, 4);
180-
EXPECT_EQ(mpi.my_stogroup, 1);
181-
EXPECT_EQ(mpi.rank_in_stogroup, 1);
182-
EXPECT_EQ(mpi.my_pool, 0);
183-
EXPECT_EQ(mpi.rank_in_pool, 1);
184-
EXPECT_EQ(mpi.nproc_in_pool, 2);
185-
EXPECT_EQ(MPI_COMM_WORLD != STO_WORLD, true);
186-
EXPECT_EQ(STO_WORLD != POOL_WORLD, true);
187-
EXPECT_EQ(MPI_COMM_WORLD != PARAPW_WORLD, true);
179+
mpi.my_pool), ::testing::ExitedWithCode(1), "");
180+
std::string output = testing::internal::GetCapturedStdout();
181+
EXPECT_THAT(output, testing::HasSubstr("Error:"));
188182
}
189183

190-
TEST_F(ParaGlobal, DividePools)
191-
{
192-
nproc = 12;
193-
mpi.kpar = 3;
194-
mpi.nstogroup = 3;
195-
this->my_rank = 5;
196-
197-
Parallel_Global::divide_pools(nproc,
198-
this->my_rank,
199-
mpi.nstogroup,
200-
mpi.kpar,
201-
mpi.nproc_in_stogroup,
202-
mpi.rank_in_stogroup,
203-
mpi.my_stogroup,
204-
mpi.nproc_in_pool,
205-
mpi.rank_in_pool,
206-
mpi.my_pool);
207-
EXPECT_EQ(mpi.nproc_in_stogroup, 4);
208-
EXPECT_EQ(mpi.my_stogroup, 1);
209-
EXPECT_EQ(mpi.rank_in_stogroup, 1);
210-
EXPECT_EQ(mpi.my_pool, 0);
211-
EXPECT_EQ(mpi.rank_in_pool, 1);
212-
EXPECT_EQ(mpi.nproc_in_pool, 2);
213-
EXPECT_EQ(MPI_COMM_WORLD != STO_WORLD, true);
214-
EXPECT_EQ(STO_WORLD != POOL_WORLD, true);
215-
EXPECT_EQ(MPI_COMM_WORLD != PARAPW_WORLD, true);
216-
}
217184

218185
TEST_F(ParaGlobal, DivideMPIPools)
219186
{

source/module_elecstate/module_charge/charge_mpi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#ifdef __MPI
99
void Charge::init_chgmpi()
1010
{
11-
if (GlobalV::NPROC_IN_STOGROUP % GlobalV::KPAR == 0)
11+
if (INTER_POOL != MPI_COMM_NULL)
1212
{
1313
this->use_intel_pool = true;
1414
}

source/module_elecstate/test_mpi/charge_mpi_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
6666
if (GlobalV::NPROC >= 2 && GlobalV::NPROC % 2 == 0)
6767
{
6868
GlobalV::KPAR = 2;
69-
Parallel_Global::divide_pools(GlobalV::NPROC,
69+
Parallel_Global::init_pools(GlobalV::NPROC,
7070
GlobalV::MY_RANK,
7171
PARAM.input.bndpar,
7272
GlobalV::KPAR,

source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void Stochastic_WF<T, Device>::init_sto_orbitals(const int seed_in)
7272
}
7373
else
7474
{
75-
srand((unsigned)std::abs(seed_in) + GlobalV::MY_RANK * 10000);
75+
srand((unsigned)std::abs(seed_in) + (GlobalV::MY_STOGROUP * GlobalV::NPROC_IN_STOGROUP + GlobalV::RANK_IN_STOGROUP) * 10000);
7676
}
7777

7878
this->allocate_chi0();

0 commit comments

Comments
 (0)