Skip to content

Commit 8e9c716

Browse files
committed
refactor: change the order of kpar and bndpar
ABACUS first divides nprocs into kpar groups and then divides these groups into bndpar subgroups
1 parent 14b77df commit 8e9c716

File tree

6 files changed

+126
-61
lines changed

6 files changed

+126
-61
lines changed

source/module_base/global_variable.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,9 @@ extern int GSIZE;
4747
extern int KPAR_LCAO;
4848

4949
//==========================================================
50-
// EXPLAIN : readin file dir, output file std::ofstream
51-
// GLOBAL VARIABLES :
52-
// NAME : global_in_card
53-
// NAME : stru_file
54-
// NAME : global_kpoint_card
55-
// NAME : global_wannier_card
56-
// NAME : global_pseudo_dir
57-
// NAME : global_pseudo_type // mohan add 2013-05-20 (xiaohui add 2013-06-23)
58-
// NAME : global_out_dir
5950
// NAME : ofs_running( contains information during running)
6051
// NAME : ofs_warning( contain warning information, including error)
6152
//==========================================================
62-
// extern std::string global_pseudo_type; // mohan add 2013-05-20 (xiaohui add
63-
// 2013-06-23)
6453
extern std::ofstream ofs_running;
6554
extern std::ofstream ofs_warning;
6655
extern std::ofstream ofs_info;

source/module_base/parallel_comm.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#if defined __MPI
22

33
#include "mpi.h"
4+
#include "parallel_global.h"
45

56
MPI_Comm POOL_WORLD;
67
MPI_Comm INTER_POOL = MPI_COMM_NULL; // communicator among different pools
@@ -9,4 +10,46 @@ MPI_Comm PARAPW_WORLD;
910
MPI_Comm GRID_WORLD; // mohan add 2012-01-13
1011
MPI_Comm DIAG_WORLD; // mohan add 2012-01-13
1112

13+
/// Bind this helper to `parent_comm` and cache the size of that communicator
/// (gsize) together with the calling process's rank in it (grank).
MPICommGroup::MPICommGroup(MPI_Comm parent_comm)
{
    this->parent_comm = parent_comm;
    MPI_Comm_size(this->parent_comm, &gsize);
    MPI_Comm_rank(this->parent_comm, &grank);
}
19+
20+
/// Release the communicators held by this object. Handles that are still
/// MPI_COMM_NULL are skipped; group_comm is freed before inter_comm.
MPICommGroup::~MPICommGroup()
{
    MPI_Comm* const owned[] = {&group_comm, &inter_comm};
    for (MPI_Comm* comm : owned)
    {
        if (*comm == MPI_COMM_NULL)
        {
            continue;
        }
        MPI_Comm_free(comm);
    }
}
31+
32+
void MPICommGroup::divide_group_comm(const int& ngroup, const bool assert_even)
33+
{
34+
this->ngroups = ngroup;
35+
Parallel_Global::divide_mpi_groups(this->gsize,
36+
ngroup,
37+
this->grank,
38+
this->nprocs_in_group,
39+
this->my_group,
40+
this->rank_in_group,
41+
assert_even);
42+
43+
MPI_Comm_split(parent_comm, my_group, rank_in_group, &group_comm);
44+
if(this->gsize % ngroup == 0)
45+
{
46+
this->is_even = true;
47+
}
48+
49+
if (this->is_even)
50+
{
51+
MPI_Comm_split(parent_comm, my_inter, rank_in_inter, &inter_comm);
52+
}
53+
}
54+
1255
#endif

source/module_base/parallel_comm.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define PARALLEL_COMM_H
33

44
#ifdef __MPI
5-
65
#include "mpi.h"
76
extern MPI_Comm POOL_WORLD;
87
extern MPI_Comm INTER_POOL; // communicator among different pools
@@ -11,6 +10,33 @@ extern MPI_Comm PARAPW_WORLD;
1110
extern MPI_Comm GRID_WORLD; // mohan add 2012-01-13
1211
extern MPI_Comm DIAG_WORLD; // mohan add 2012-01-13
1312

13+
14+
class MPICommGroup
15+
{
16+
public:
17+
MPICommGroup(MPI_Comm parent_comm);
18+
~MPICommGroup();
19+
void divide_group_comm(const int& ngroup, const bool assert_even = true);
20+
public:
21+
bool is_even = false; ///< whether the group is even
22+
23+
MPI_Comm parent_comm = MPI_COMM_NULL; ///< parent communicator
24+
int gsize = 0; ///< size of parent communicator
25+
int grank = 0; ///< rank of parent communicator
26+
27+
MPI_Comm group_comm = MPI_COMM_NULL; ///< group communicator
28+
int ngroups = 0; ///< number of groups
29+
int nprocs_in_group = 0; ///< number of processes in the group
30+
int my_group = 0; ///< the group index
31+
int rank_in_group = 0; ///< the rank in the group
32+
33+
MPI_Comm inter_comm = MPI_COMM_NULL; ///< inter communicator
34+
bool has_inter_comm = false; ///< whether has inter communicator
35+
int& nprocs_in_inter = ngroups; ///< number of processes in the inter communicator
36+
int& my_inter = rank_in_group; ///< the rank in the inter communicator
37+
int& rank_in_inter = my_group; ///< the inter group index
38+
};
39+
1440
#endif
1541

1642
#endif // PARALLEL_COMM_H

source/module_base/parallel_global.cpp

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ void Parallel_Global::finalize_mpi()
251251

252252
void Parallel_Global::init_pools(const int& NPROC,
253253
const int& MY_RANK,
254-
const int& NSTOGROUP,
254+
const int& BNDPAR,
255255
const int& KPAR,
256256
int& NPROC_IN_STOGROUP,
257257
int& RANK_IN_STOGROUP,
@@ -266,7 +266,7 @@ void Parallel_Global::init_pools(const int& NPROC,
266266
//----------------------------------------------------------
267267
Parallel_Global::divide_pools(NPROC,
268268
MY_RANK,
269-
NSTOGROUP,
269+
BNDPAR,
270270
KPAR,
271271
NPROC_IN_STOGROUP,
272272
RANK_IN_STOGROUP,
@@ -314,7 +314,7 @@ void Parallel_Global::init_pools(const int& NPROC,
314314
#ifdef __MPI
315315
void Parallel_Global::divide_pools(const int& NPROC,
316316
const int& MY_RANK,
317-
const int& NSTOGROUP,
317+
const int& BNDPAR,
318318
const int& KPAR,
319319
int& NPROC_IN_STOGROUP,
320320
int& RANK_IN_STOGROUP,
@@ -323,30 +323,51 @@ void Parallel_Global::divide_pools(const int& NPROC,
323323
int& RANK_IN_POOL,
324324
int& MY_POOL)
325325
{
326-
// Divide the global communicator into stogroups.
327-
divide_mpi_groups(NPROC, NSTOGROUP, MY_RANK, NPROC_IN_STOGROUP, MY_STOGROUP, RANK_IN_STOGROUP, true);
328-
329-
// (2) per process in each pool
330-
divide_mpi_groups(NPROC_IN_STOGROUP, KPAR, RANK_IN_STOGROUP, NPROC_IN_POOL, MY_POOL, RANK_IN_POOL);
331-
332-
int key = 1;
333-
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, key, &STO_WORLD);
334-
335-
//========================================================
336-
// MPI_Comm_Split: Creates new communicators based on
337-
// colors(2nd parameter) and keys(3rd parameter)
338-
// Note: The color must be non-negative or MPI_UNDEFINED.
339-
//========================================================
340-
MPI_Comm_split(STO_WORLD, MY_POOL, key, &POOL_WORLD);
341-
342-
if (NPROC_IN_STOGROUP % KPAR == 0)
326+
// note: the order of k-point parallelization and band parallelization is important
327+
// The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL
328+
// and MY_STOGROUP will be the same as well.
329+
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
343330
{
344-
MPI_Comm_split(STO_WORLD, RANK_IN_POOL, key, &INTER_POOL);
331+
std::cout << "Error: When " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
332+
<< BNDPAR * KPAR << ")." << std::endl;
333+
exit(1);
334+
}
335+
// k-point parallelization
336+
MPICommGroup kpar_group(MPI_COMM_WORLD);
337+
kpar_group.divide_group_comm(KPAR, false);
338+
339+
// band parallelization
340+
MPICommGroup bndpar_group(kpar_group.group_comm);
341+
bndpar_group.divide_group_comm(BNDPAR, true);
342+
343+
// Set parallel index.
344+
// In previous versions, the order of k-point parallelization and band parallelization is reversed.
345+
// So the following code is kept for compatibility.
346+
NPROC_IN_POOL = bndpar_group.nprocs_in_group;
347+
RANK_IN_POOL = bndpar_group.rank_in_group;
348+
MY_POOL = kpar_group.my_group;
349+
// INTER_POOL = kpar_group.inter_comm;
350+
// POOL_WORLD = bndpar_group.group_comm;
351+
MPI_Comm_dup(bndpar_group.group_comm, &POOL_WORLD);
352+
MPI_Comm_dup(kpar_group.inter_comm, &INTER_POOL);
353+
354+
if (KPAR > 1)
355+
{
356+
NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
357+
RANK_IN_STOGROUP = kpar_group.rank_in_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group;
358+
MY_STOGROUP = bndpar_group.my_group;
359+
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_STOGROUP, &STO_WORLD);
360+
MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD);
361+
}
362+
else
363+
{
364+
NPROC_IN_STOGROUP = NPROC;
365+
RANK_IN_STOGROUP = MY_RANK;
366+
MY_STOGROUP = 0;
367+
// STO_WORLD = MPI_COMM_WORLD;
368+
MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD);
369+
MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD);
345370
}
346-
347-
int color = MY_RANK % NPROC_IN_STOGROUP;
348-
MPI_Comm_split(MPI_COMM_WORLD, color, key, &PARAPW_WORLD);
349-
350371
return;
351372
}
352373

@@ -380,31 +401,17 @@ void Parallel_Global::divide_mpi_groups(const int& procs,
380401
exit(1);
381402
}
382403

383-
int* nproc_group_ = new int[num_groups];
384-
385-
for (int i = 0; i < num_groups; i++)
404+
if(rank < extra_procs)
386405
{
387-
nproc_group_[i] = procs_in_group;
388-
if (i < extra_procs)
389-
{
390-
++nproc_group_[i];
391-
}
406+
procs_in_group++;
407+
my_group = rank / procs_in_group;
408+
rank_in_group = rank % procs_in_group;
392409
}
393-
394-
int np_now = 0;
395-
for (int i = 0; i < num_groups; i++)
410+
else
396411
{
397-
np_now += nproc_group_[i];
398-
if (rank < np_now)
399-
{
400-
my_group = i;
401-
procs_in_group = nproc_group_[i];
402-
rank_in_group = rank - (np_now - procs_in_group);
403-
break;
404-
}
412+
my_group = (rank - extra_procs) / procs_in_group;
413+
rank_in_group = (rank - extra_procs) % procs_in_group;
405414
}
406-
407-
delete[] nproc_group_;
408415
}
409416

410417
#endif

source/module_base/parallel_global.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void split_grid_world(const int diag_np, const int& nproc, const int& my_rank, i
4646
*/
4747
void init_pools(const int& NPROC,
4848
const int& MY_RANK,
49-
const int& NSTOGROUP,
49+
const int& BNDPAR,
5050
const int& KPAR,
5151
int& NPROC_IN_STOGROUP,
5252
int& RANK_IN_STOGROUP,
@@ -57,7 +57,7 @@ void init_pools(const int& NPROC,
5757

5858
void divide_pools(const int& NPROC,
5959
const int& MY_RANK,
60-
const int& NSTOGROUP,
60+
const int& BNDPAR,
6161
const int& KPAR,
6262
int& NPROC_IN_STOGROUP,
6363
int& RANK_IN_STOGROUP,

source/module_elecstate/module_charge/charge_mpi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#ifdef __MPI
99
void Charge::init_chgmpi()
1010
{
11-
if (GlobalV::NPROC_IN_STOGROUP % GlobalV::KPAR == 0)
11+
if (INTER_POOL != MPI_COMM_NULL)
1212
{
1313
this->use_intel_pool = true;
1414
}

0 commit comments

Comments
 (0)