
Commit c15c823

make bpcg support bndpar > 1
1 parent f789080 commit c15c823

38 files changed: 809 additions, 601 deletions

source/module_base/para_gemm.cpp

Lines changed: 261 additions & 124 deletions
Large diffs are not rendered by default.

source/module_base/para_gemm.h

Lines changed: 13 additions & 3 deletions
@@ -36,7 +36,7 @@ class PGemmCN
      * @param LDB leading dimension of B in each proc
      * @param nrow number of rows of A or B
      * @param LDC leading dimension of C. C can be C_local or C_global
-     * @param gatherC whether gather C_local to C_global
+     * @param mode 1: gather C_local to C_global, 2:C_local(nrow * ncol_loc), 3:C_global(nrow_loc * ncol)
      */
     void set_dimension(
 #ifdef __MPI
@@ -49,7 +49,7 @@ class PGemmCN
         const int LDB,
         const int nrow,
         const int LDC,
-        const bool gatherC = true);
+        const int mode = 1);
 
     /**
      * @brief calculate C = alpha * A^H * B + beta * C
@@ -67,14 +67,16 @@ class PGemmCN
 
     std::vector<int> colA_loc; ///< [col_nproc] number of columns of A matrix in each proc
     int max_colA = 0;          ///< maximum number of columns of A matrix in all procs
-    std::vector<int> colB_loc; ///<[col_nproc] number of columns of B matrix in each proc
+    std::vector<int> colB_loc; ///< [col_nproc] number of columns of B matrix in each proc
+    int max_colB = 0;          ///< maximum number of columns of B matrix in all procs
 
     std::vector<MPI_Request> requests; ///< MPI request
     std::vector<int> recv_counts;      ///< receive counts for gathering C_local to C_global
     std::vector<int> displs;           ///< displacements for gathering C_local to C_global
     int size_C_local = 0;  ///< size of C_local, which is a local matrix in each proc
     int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs
     bool gatherC = true;   ///< whether gather C_local to C_global
+    bool divideCrow = false; ///< whether divide C_global to C_local
 #endif
     int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc
     int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc
@@ -83,6 +85,14 @@ class PGemmCN
     int LDB = 0; ///< leading dimension of B in each proc
     int LDC = 0; ///< leading dimension of C, which can be C_local or C_global
   private:
+    /// @brief for col_nproc == 1
+    void multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C);
+#ifdef __MPI
+    /// @brief for mode = 1 or 2
+    void multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C);
+    /// @brief for mode = 3
+    void multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C);
+#endif
     using resmem_dev_op = base_device::memory::resize_memory_op<T, Device>;
     using delmem_dev_op = base_device::memory::delete_memory_op<T, Device>;
     using syncmem_dev_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
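
The boolean gatherC argument of set_dimension becomes an integer mode: mode 1 keeps the old behavior (gather C_local into C_global), mode 2 leaves C as a column-local block, and mode 3 leaves C distributed over rows. A minimal call sketch, following the usage in test_para_gemm.cpp further down; the communicators, local sizes, buffers, and the exact template arguments are assumptions supplied by the caller, not part of this commit:

    // Sketch only: col_comm/row_comm, the local sizes, and the A/B/C buffers
    // are assumed to be prepared by the caller (see the unit test below).
    PGemmCN<double> pgemm; // template/device arguments abbreviated, see para_gemm.h
    pgemm.set_dimension(col_comm, row_comm,
                        ncolA, LDA,  // local columns and leading dimension of A
                        ncolB, LDB,  // local columns and leading dimension of B
                        nrow,        // rows shared by A and B
                        LDC,
                        3);          // mode 3: C stays distributed over rows (nrow_loc * ncol)
    pgemm.multiply(alpha, A_local, B_local, beta, C_local);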

source/module_base/parallel_comm.cpp

Lines changed: 4 additions & 4 deletions
@@ -3,10 +3,10 @@
 #include "mpi.h"
 #include "parallel_global.h"
 
-MPI_Comm POOL_WORLD;
-MPI_Comm INTER_POOL; // communicator among different pools
-MPI_Comm STO_WORLD;
-MPI_Comm PARAPW_WORLD;
+MPI_Comm POOL_WORLD;  // groups for different plane waves. In this group, only plane waves are different. K-points and bands are the same.
+MPI_Comm KP_WORLD;    // groups for different k. In this group, only k-points are different. Bands and plane waves are the same.
+MPI_Comm BP_WORLD;    // groups for different bands. In this group, only bands are different. K-points and plane waves are the same.
+MPI_Comm INT_BGROUP;  // internal comm groups for the same bands. In this group, only bands are the same. K-points and plane waves are different.
 MPI_Comm GRID_WORLD; // mohan add 2012-01-13
 MPI_Comm DIAG_WORLD; // mohan add 2012-01-13

source/module_base/parallel_comm.h

Lines changed: 3 additions & 3 deletions
@@ -4,9 +4,9 @@
 #ifdef __MPI
 #include "mpi.h"
 extern MPI_Comm POOL_WORLD;
-extern MPI_Comm INTER_POOL; // communicator among different pools
-extern MPI_Comm STO_WORLD;
-extern MPI_Comm PARAPW_WORLD;
+extern MPI_Comm KP_WORLD; // communicator among different pools
+extern MPI_Comm INT_BGROUP;
+extern MPI_Comm BP_WORLD;
 extern MPI_Comm GRID_WORLD; // mohan add 2012-01-13
 extern MPI_Comm DIAG_WORLD; // mohan add 2012-01-13

source/module_base/parallel_global.cpp

Lines changed: 11 additions & 11 deletions
@@ -237,12 +237,12 @@ void Parallel_Global::read_pal_param(int argc,
 void Parallel_Global::finalize_mpi()
 {
     MPI_Comm_free(&POOL_WORLD);
-    if (INTER_POOL != MPI_COMM_NULL)
+    if (KP_WORLD != MPI_COMM_NULL)
     {
-        MPI_Comm_free(&INTER_POOL);
+        MPI_Comm_free(&KP_WORLD);
     }
-    MPI_Comm_free(&STO_WORLD);
-    MPI_Comm_free(&PARAPW_WORLD);
+    MPI_Comm_free(&INT_BGROUP);
+    MPI_Comm_free(&BP_WORLD);
     MPI_Comm_free(&GRID_WORLD);
     MPI_Comm_free(&DIAG_WORLD);
     MPI_Finalize();
@@ -324,7 +324,7 @@ void Parallel_Global::divide_pools(const int& NPROC,
                                    int& MY_POOL)
 {
     // note: the order of k-point parallelization and band parallelization is important
-    // The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL
+    // The order will not change the behavior of KP_WORLD or BP_WORLD, and MY_POOL
     // and MY_BNDGROUP will be the same as well.
     if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
     {
@@ -349,28 +349,28 @@ void Parallel_Global::divide_pools(const int& NPROC,
     MPI_Comm_dup(bndpar_group.group_comm, &POOL_WORLD);
     if(kpar_group.inter_comm != MPI_COMM_NULL)
     {
-        MPI_Comm_dup(kpar_group.inter_comm, &INTER_POOL);
+        MPI_Comm_dup(kpar_group.inter_comm, &KP_WORLD);
     }
     else
     {
-        INTER_POOL = MPI_COMM_NULL;
+        KP_WORLD = MPI_COMM_NULL;
     }
 
     if(BNDPAR > 1)
     {
         NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
         RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group;
         MY_BNDGROUP = bndpar_group.my_group;
-        MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &STO_WORLD);
-        MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD);
+        MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &INT_BGROUP);
+        MPI_Comm_dup(bndpar_group.inter_comm, &BP_WORLD);
     }
     else
     {
         NPROC_IN_BNDGROUP = NPROC;
         RANK_IN_BPGROUP = MY_RANK;
         MY_BNDGROUP = 0;
-        MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD);
-        MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD);
+        MPI_Comm_dup(MPI_COMM_WORLD, &INT_BGROUP);
+        MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &BP_WORLD);
     }
     return;
 }
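
Conceptually, divide_pools arranges the ranks into a BNDPAR x KPAR grid: INT_BGROUP connects the ranks that hold the same bands, while BP_WORLD connects ranks that hold the same k-points and plane waves but different bands. The standalone MPI sketch below only illustrates that color/key idea; it is not the ABACUS implementation, which builds the groups from the kpar_group/bndpar_group helpers shown above, and the kpar/bndpar values and rank layout are placeholders:

    #include <mpi.h>
    #include <cstdio>

    int main(int argc, char** argv)
    {
        MPI_Init(&argc, &argv);
        int rank = 0, nproc = 0;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &nproc);

        const int kpar = 2;   // hypothetical number of k-point pools
        const int bndpar = 2; // hypothetical number of band groups
        if (nproc % (kpar * bndpar) != 0) { MPI_Abort(MPI_COMM_WORLD, 1); }

        const int nproc_per_bgroup = nproc / bndpar;
        const int nproc_per_pool = nproc_per_bgroup / kpar;
        const int my_bndgroup = rank / nproc_per_bgroup;       // which band group I belong to
        const int rank_in_bgroup = rank % nproc_per_bgroup;    // my position inside that group
        const int my_pool = rank_in_bgroup / nproc_per_pool;   // which k-point pool inside the group

        MPI_Comm int_bgroup; // analogue of INT_BGROUP: same bands, different k-points / plane waves
        MPI_Comm bp_world;   // analogue of BP_WORLD: same k-points / plane waves, different bands
        MPI_Comm_split(MPI_COMM_WORLD, my_bndgroup, rank_in_bgroup, &int_bgroup);
        MPI_Comm_split(MPI_COMM_WORLD, rank_in_bgroup, my_bndgroup, &bp_world);

        std::printf("world rank %d -> k-pool %d, band group %d\n", rank, my_pool, my_bndgroup);

        MPI_Comm_free(&int_bgroup);
        MPI_Comm_free(&bp_world);
        MPI_Finalize();
        return 0;
    }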

source/module_base/parallel_reduce.cpp

Lines changed: 4 additions & 4 deletions
@@ -110,9 +110,9 @@ void Parallel_Reduce::reduce_pool<double>(double* object, const int n)
 
 // (1) the value is same in each pool.
 // (2) we need to reduce the value from different pool.
-void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object)
+void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object)
 {
-    if (kpar == 1)
+    if (npool == 1)
     {
         return;
     }
@@ -124,9 +124,9 @@ void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in
 
 // (1) the value is same in each pool.
 // (2) we need to reduce the value from different pool.
-void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n)
+void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n)
 {
-    if (kpar == 1)
+    if (npool == 1)
     {
         return;
     }

source/module_base/parallel_reduce.h

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@ void reduce_int_grid(int* object, const int n); // mohan add 2012-01-12
 void reduce_double_grid(double* object, const int n);
 void reduce_double_diag(double* object, const int n);
 
-void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object);
-void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n);
+void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object);
+void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n);
 
 void gather_min_int_all(const int& nproc, int& v);
 void gather_max_double_all(const int& nproc, double& v);
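
Since band parallelization multiplies the number of pools, callers now pass npool = KPAR * bndpar instead of KPAR alone. A hedged usage sketch mirroring the elecstate.cpp change further down; the GlobalV/PARAM symbols come from the surrounding ABACUS code and 'energy' is a placeholder variable:

    #ifdef __MPI
    // Reduce a quantity that is identical inside each pool but partial across pools.
    const int npool = GlobalV::KPAR * PARAM.inp.bndpar;
    Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, energy);
    #endif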

source/module_base/test_parallel/test_para_gemm.cpp

Lines changed: 78 additions & 27 deletions
@@ -367,7 +367,49 @@ TYPED_TEST(PgemmTest, odd_case)
     this->compare_result(ncolA_global, ncolB_global, LDC_global);
 }
 
-TYPED_TEST(PgemmTest, odd_case_not_gather)
+TYPED_TEST(PgemmTest, row_parallel)
+{
+    const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
+    const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
+
+    this->decide_ngroup(1, 4);
+    this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
+
+    this->pgemm.set_dimension(this->col_world,
+                              this->row_world,
+                              this->ncolA,
+                              this->LDA,
+                              this->ncolB,
+                              this->LDB,
+                              this->nrow,
+                              LDC_global);
+    this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data());
+
+    this->compare_result(ncolA_global, ncolB_global, LDC_global);
+}
+
+TYPED_TEST(PgemmTest, col_parallel)
+{
+    const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
+    const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
+
+    this->decide_ngroup(4, 1);
+    this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
+
+    this->pgemm.set_dimension(this->col_world,
+                              this->row_world,
+                              this->ncolA,
+                              this->LDA,
+                              this->ncolB,
+                              this->LDB,
+                              this->nrow,
+                              LDC_global);
+    this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data());
+
+    this->compare_result(ncolA_global, ncolB_global, LDC_global);
+}
+
+TYPED_TEST(PgemmTest, divide_col)
 {
     const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
     const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
@@ -392,7 +434,7 @@ TYPED_TEST(PgemmTest, odd_case_not_gather)
                               this->LDB,
                               this->nrow,
                               LDC_global,
-                              false);
+                              2);
     this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start);
 
 
@@ -408,34 +450,32 @@ TYPED_TEST(PgemmTest, odd_case_not_gather)
     }
 }
 
-TYPED_TEST(PgemmTest, row_parallel)
+TYPED_TEST(PgemmTest, divide_row)
 {
     const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
     const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
 
-    this->decide_ngroup(1, 4);
+    this->decide_ngroup(2, 2);
     this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
+    std::vector<int> colA_loc(this->nproc_col);
+    MPI_Allgather(&this->ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, this->col_world);
+    std::vector<int> displs(this->nproc_col);
+    displs[0] = 0;
+    for (int i = 1; i < this->nproc_col; i++)
+    {
+        displs[i] = (displs[i - 1] + colA_loc[i - 1]);
+    }
+    int start = displs[this->rank_col];
 
-    this->pgemm.set_dimension(this->col_world,
-                              this->row_world,
-                              this->ncolA,
-                              this->LDA,
-                              this->ncolB,
-                              this->LDB,
-                              this->nrow,
-                              LDC_global);
-    this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data());
-
-    this->compare_result(ncolA_global, ncolB_global, LDC_global);
-}
-
-TYPED_TEST(PgemmTest, col_parallel)
-{
-    const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
-    const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
-
-    this->decide_ngroup(4, 1);
-    this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
+    int LDC_local = this->ncolA + 2;
+    std::vector<TypeParam> C_loc(LDC_local * ncolB_global, 0.0);
+    for(int i = 0; i < ncolB_global; i++)
+    {
+        for(int j = 0; j < this->ncolA; j++)
+        {
+            C_loc[i * LDC_local + j] = this->C_global[i * LDC_global + start + j];
+        }
+    }
 
     this->pgemm.set_dimension(this->col_world,
                               this->row_world,
@@ -444,10 +484,21 @@ TYPED_TEST(PgemmTest, col_parallel)
                               this->ncolB,
                               this->LDB,
                               this->nrow,
-                              LDC_global);
-    this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data());
+                              LDC_local,
+                              3);
+    this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, C_loc.data());
 
-    this->compare_result(ncolA_global, ncolB_global, LDC_global);
+
+
+    for (int i = 0; i < ncolB_global; i++)
+    {
+        for (int j = 0; j < this->ncolA; j++)
+        {
+            EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + start + j]),
+                        get_double(C_loc[i * LDC_local + j]),
+                        1e-10);
+        }
+    }
 }
 
 int main(int argc, char** argv)

source/module_elecstate/elecstate.cpp

Lines changed: 7 additions & 12 deletions
@@ -163,8 +163,8 @@ void ElecState::calculate_weights()
                                       this->klist->isk);
     }
 #ifdef __MPI
-        // qianrui fix a bug on 2021-7-21
-        Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, this->f_en.demet);
+        const int npool = GlobalV::KPAR * PARAM.inp.bndpar;
+        Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.demet);
 #endif
     }
     else if (Occupy::fixed_occupations)
@@ -192,16 +192,11 @@ void ElecState::calEBand()
         }
     }
     this->f_en.eband = eband;
-    if (GlobalV::KPAR != 1 && PARAM.inp.esolver_type != "sdft")
-    {
-        //==================================
-        // Reduce all the Energy in each cpu
-        //==================================
-        this->f_en.eband /= GlobalV::NPROC_IN_POOL;
+
 #ifdef __MPI
-        Parallel_Reduce::reduce_all(this->f_en.eband);
+    const int npool = GlobalV::KPAR * PARAM.inp.bndpar;
+    Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.eband);
 #endif
-    }
     return;
 }
 
@@ -253,8 +248,8 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge
     // init nelec_spin with nelec and nupdown
     this->init_nelec_spin();
     // initialize ekb and wg
-    this->ekb.create(nk_in, PARAM.inp.nbands);
-    this->wg.create(nk_in, PARAM.inp.nbands);
+    this->ekb.create(nk_in, PARAM.globalv.nbands_l);
+    this->wg.create(nk_in, PARAM.globalv.nbands_l);
 }
 
 } // namespace elecstate

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
         setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
     }
 
-    if (GlobalV::MY_BNDGROUP == 0 || PARAM.inp.ks_solver == "bpcg")
+    if (PARAM.globalv.ks_run)
     {
         for (int ik = 0; ik < psi.get_nk(); ++ik)
{
