Skip to content

Commit 09dcf7d

Browse files
committed
add change
1 parent 140b272 commit 09dcf7d

File tree

27 files changed

+246
-149
lines changed

27 files changed

+246
-149
lines changed

source/source_base/parallel_reduce.cpp

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,24 @@ void Parallel_Reduce::reduce_pool<double>(double* object, const int n)
126126
return;
127127
}
128128

129+
template <>
130+
void Parallel_Reduce::reduce_pool<int>(int& object)
131+
{
132+
#ifdef __MPI
133+
MPI_Allreduce(MPI_IN_PLACE, &object, 1, MPI_INT, MPI_SUM, POOL_WORLD);
134+
#endif
135+
return;
136+
}
137+
138+
template <>
139+
void Parallel_Reduce::reduce_pool<int>(int* object, const int n)
140+
{
141+
#ifdef __MPI
142+
MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_INT, MPI_SUM, POOL_WORLD);
143+
#endif
144+
return;
145+
}
146+
129147
// (1) the value is same in each pool.
130148
// (2) we need to reduce the value from different pool.
131149
void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object)
@@ -314,4 +332,95 @@ void Parallel_Reduce::gather_min_double_all(const int& nproc, double& v)
314332
}
315333
}
316334
#endif
317-
}
335+
}
336+
337+
// Replace `v` on every rank with the maximum of `v` over all ranks of
// MPI_COMM_WORLD.
//
// nproc : kept for interface compatibility; it is no longer used. The
//         reduction spans whatever MPI_COMM_WORLD contains, so a caller-supplied
//         count that disagrees with the communicator size can neither overrun a
//         receive buffer nor mix spurious zero entries into the maximum.
// v     : in  - this rank's value
//         out - the maximum over all ranks
//
// When __MPI is not defined this is a no-op and `v` is left unchanged.
void Parallel_Reduce::gather_max_int_all(const int& nproc, int& v)
{
    static_cast<void>(nproc); // intentionally unused, see above
#ifdef __MPI
    MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
#endif
}
351+
352+
// Replace `v` on every rank of POOL_WORLD with the maximum of `v` over that
// communicator.
//
// nproc_in_pool : only used as a shortcut: when it equals 1 the call returns
//                 immediately without any communication. It is not used to size
//                 any buffer, so a value that disagrees with the real size of
//                 POOL_WORLD can neither overrun memory nor inject spurious
//                 zero entries into the maximum.
// v             : in  - this rank's value
//                 out - the maximum over the ranks of POOL_WORLD
//
// When __MPI is not defined this is a no-op and `v` is left unchanged.
void Parallel_Reduce::gather_max_int_pool(const int& nproc_in_pool, int& v)
{
#ifdef __MPI
    if (nproc_in_pool == 1)
    {
        return; // single rank: v already is the maximum
    }
    MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_INT, MPI_MAX, POOL_WORLD);
#else
    static_cast<void>(nproc_in_pool);
#endif
}
370+
// Logical OR of `v` over all ranks of MPI_COMM_WORLD; every rank receives the
// combined result in `v`.
//
// The flag is reduced through an int rather than by handing the address of a
// C++ `bool` to MPI with MPI_C_BOOL: MPI_C_BOOL is the datatype of C `_Bool`,
// and the MPI standard does not promise that it matches C++ `bool`.
// MPI_INT with MPI_LOR is valid everywhere.
//
// When __MPI is not defined this is a no-op and `v` is left unchanged.
void Parallel_Reduce::gather_or_bool_all(bool& v)
{
#ifdef __MPI
    int flag = v ? 1 : 0;
    MPI_Allreduce(MPI_IN_PLACE, &flag, 1, MPI_INT, MPI_LOR, MPI_COMM_WORLD);
    v = (flag != 0);
#endif
}
376+
377+
// Logical OR of `v` over all ranks of BP_WORLD; every rank of that
// communicator receives the combined result in `v`.
//
// The flag is reduced through an int rather than by handing the address of a
// C++ `bool` to MPI with MPI_C_BOOL: MPI_C_BOOL is the datatype of C `_Bool`,
// and the MPI standard does not promise that it matches C++ `bool`.
// MPI_INT with MPI_LOR is valid everywhere.
//
// When __MPI is not defined this is a no-op and `v` is left unchanged.
void Parallel_Reduce::gather_or_bool_bp(bool& v)
{
#ifdef __MPI
    int flag = v ? 1 : 0;
    MPI_Allreduce(MPI_IN_PLACE, &flag, 1, MPI_INT, MPI_LOR, BP_WORLD);
    v = (flag != 0);
#endif
}
383+
384+
// Element-wise in-place sum of the first `n` doubles of `object` over the
// ranks of KP_WORLD. Nothing is done when KP_WORLD is MPI_COMM_NULL, or when
// __MPI is not defined.
void Parallel_Reduce::reduce_kp(double* object, const int n)
{
#ifdef __MPI
    if (KP_WORLD == MPI_COMM_NULL)
    {
        return; // no communicator to reduce over
    }
    MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_DOUBLE, MPI_SUM, KP_WORLD);
#endif
}
391+
392+
// Element-wise in-place sum of the first `n` doubles of `object` over the
// ranks of BP_WORLD. No-op when __MPI is not defined.
void Parallel_Reduce::reduce_bp(double* object, const int n)
{
#ifdef __MPI
    double* const buffer = object; // same storage is input and output
    MPI_Allreduce(MPI_IN_PLACE, buffer, n, MPI_DOUBLE, MPI_SUM, BP_WORLD);
#endif
}
398+
399+
// Element-wise in-place sum of the first `n` doubles of `object` over the
// ranks of INT_BGROUP. No-op when __MPI is not defined.
void Parallel_Reduce::reduce_bgroup(double* object, const int n)
{
#ifdef __MPI
    double* const buffer = object; // same storage is input and output
    MPI_Allreduce(MPI_IN_PLACE, buffer, n, MPI_DOUBLE, MPI_SUM, INT_BGROUP);
#endif
}
405+
406+
// Element-wise in-place sum of the first `n` ints of `object` over the ranks
// of KP_WORLD. Nothing is done when KP_WORLD is MPI_COMM_NULL, or when __MPI
// is not defined.
void Parallel_Reduce::reduce_kp(int* object, const int n)
{
#ifdef __MPI
    if (KP_WORLD == MPI_COMM_NULL)
    {
        return; // no communicator to reduce over
    }
    MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_INT, MPI_SUM, KP_WORLD);
#endif
}
413+
414+
// Element-wise in-place sum of the first `n` ints of `object` over the ranks
// of BP_WORLD. No-op when __MPI is not defined.
void Parallel_Reduce::reduce_bp(int* object, const int n)
{
#ifdef __MPI
    int* const buffer = object; // same storage is input and output
    MPI_Allreduce(MPI_IN_PLACE, buffer, n, MPI_INT, MPI_SUM, BP_WORLD);
#endif
}
420+
421+
// Element-wise in-place sum of the first `n` ints of `object` over the ranks
// of INT_BGROUP. No-op when __MPI is not defined.
void Parallel_Reduce::reduce_bgroup(int* object, const int n)
{
#ifdef __MPI
    int* const buffer = object; // same storage is input and output
    MPI_Allreduce(MPI_IN_PLACE, buffer, n, MPI_INT, MPI_SUM, INT_BGROUP);
#endif
}

source/source_base/parallel_reduce.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ void gather_min_double_all(const int& nproc, double& v);
4040
void gather_max_double_pool(const int& nproc_in_pool, double& v);
4141
void gather_min_double_pool(const int& nproc_in_pool, double& v);
4242

43+
// Integer maximum across ranks; v is overwritten with the result on every rank.
// gather_max_int_all works on MPI_COMM_WORLD. nproc is expected to equal the
// number of ranks in that communicator.
void gather_max_int_all(const int& nproc, int& v);
44+
// Same as above on POOL_WORLD; returns without communicating when
// nproc_in_pool == 1.
void gather_max_int_pool(const int& nproc_in_pool, int& v);
45+
46+
// Logical OR of a flag across ranks; v is overwritten with the result.
// _all reduces over MPI_COMM_WORLD, _bp over BP_WORLD.
void gather_or_bool_all(bool& v);
47+
void gather_or_bool_bp(bool& v);
48+
49+
// In-place element-wise sum of the first n doubles of object over the ranks
// of KP_WORLD / BP_WORLD / INT_BGROUP respectively. reduce_kp does nothing
// when KP_WORLD is MPI_COMM_NULL.
void reduce_kp(double* object, const int n);
50+
void reduce_bp(double* object, const int n);
51+
void reduce_bgroup(double* object, const int n);
52+
53+
// int overloads of the three reductions above (same communicators, MPI_SUM).
void reduce_kp(int* object, const int n);
54+
void reduce_bp(int* object, const int n);
55+
void reduce_bgroup(int* object, const int n);
56+
4357
// mohan add 2011-04-21
4458
void gather_int_all(int& v, int* all);
4559

source/source_basis/module_pw/pw_basis_big.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define PW_BASIS_BIG_H
33
#include "source_base/constants.h"
44
#include "source_base/global_function.h"
5+
#include "source_base/parallel_reduce.h" // Parallel_Reduce
6+
#include "source_base/parallel_global.h" // GlobalV
57
#ifdef __MPI
68
#include "mpi.h"
79
#endif
@@ -166,9 +168,9 @@ class PW_Basis_Big : public PW_Basis_Sup
166168
ibox[0] = 2*n1+1;
167169
ibox[1] = 2*n2+1;
168170
ibox[2] = 2*n3+1;
169-
#ifdef __MPI
170-
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
171-
#endif
171+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[0]);
172+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[1]);
173+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[2]);
172174

173175
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
174176
for (int i = 0; i < 3; i++)
@@ -349,9 +351,7 @@ class PW_Basis_Big : public PW_Basis_Sup
349351
}
350352
}
351353
}
352-
#ifdef __MPI
353-
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
354-
#endif
354+
Parallel_Reduce::gather_min_double_pool(GlobalV::NPROC_IN_POOL, this->gridecut_lat);
355355
this->gridecut_lat -= 1e-6;
356356

357357
delete[] ibox;

source/source_basis/module_pw/pw_init.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "pw_basis.h"
2-
#include "source_base/constants.h"
2+
#include "source_base/parallel_global.h" // GlobalV
3+
#include "source_base/parallel_reduce.h"
34

45
namespace ModulePW
56
{
@@ -85,9 +86,9 @@ void PW_Basis:: initgrids(
8586
ibox[0] = 2*n1+1;
8687
ibox[1] = 2*n2+1;
8788
ibox[2] = 2*n3+1;
88-
#ifdef __MPI
89-
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
90-
#endif
89+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[0]);
90+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[1]);
91+
Parallel_Reduce::gather_max_int_pool(GlobalV::NPROC_IN_POOL, ibox[2]);
9192

9293
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
9394
for (int i = 0; i < 3; i++)
@@ -199,9 +200,7 @@ void PW_Basis:: initgrids(
199200
}
200201
}
201202
}
202-
#ifdef __MPI
203-
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
204-
#endif
203+
Parallel_Reduce::gather_min_double_pool(GlobalV::NPROC_IN_POOL, this->gridecut_lat);
205204
this->gridecut_lat -= 1e-6;
206205

207206
delete[] ibox;

source/source_cell/parallel_kpoints.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ void Parallel_Kpoints::gatherkvec(const std::vector<ModuleBase::Vector3<double>>
124124
}
125125
}
126126

127-
MPI_Allreduce(MPI_IN_PLACE, &vec_global[0], 3 * this->nkstot_np, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
127+
#include "source_base/parallel_reduce.h"
128+
129+
Parallel_Reduce::reduce_all(&vec_global[0], 3 * this->nkstot_np);
128130
return;
129131
}
130132
#endif

source/source_estate/elecstate_energy.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,8 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
171171
}
172172
}
173173

174-
#ifdef __MPI
175-
MPI_Allreduce(&deband_aux, &deband0, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
176-
#else
177174
deband0 = deband_aux;
178-
#endif
175+
Parallel_Reduce::reduce_pool(deband0);
179176

180177
deband0 *= ucell.omega / this->charge->rhopw->nxyz;
181178

source/source_estate/module_charge/charge_mpi.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void Charge::reduce_diff_pools(double* array_rho) const
3030
ModuleBase::timer::tick("Charge", "reduce_diff_pools");
3131
if (KP_WORLD != MPI_COMM_NULL)
3232
{
33-
MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, KP_WORLD);
33+
Parallel_Reduce::reduce_kp(array_rho, this->nrxx);
3434
}
3535
else
3636
{
@@ -92,7 +92,11 @@ void Charge::reduce_diff_pools(double* array_rho) const
9292
//==================================
9393
// Reduce all the rho in each cpu
9494
//==================================
95-
MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, INT_BGROUP);
95+
for (int i = 0; i < this->rhopw->nxyz; i++)
96+
{
97+
array_tot[i] = array_tot_aux[i];
98+
}
99+
Parallel_Reduce::reduce_bgroup(array_tot, this->rhopw->nxyz);
96100

97101
//=====================================
98102
// Change the order of rho in each cpu
@@ -111,7 +115,7 @@ void Charge::reduce_diff_pools(double* array_rho) const
111115
}
112116
if(PARAM.globalv.all_ks_run && PARAM.inp.bndpar > 1)
113117
{
114-
MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, BP_WORLD);
118+
Parallel_Reduce::reduce_bp(array_rho, this->nrxx);
115119
}
116120
ModuleBase::timer::tick("Charge", "reduce_diff_pools");
117121
}

source/source_estate/module_charge/symmetry_rhog.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "symmetry_rho.h"
22
#include "source_pw/module_pwdft/global.h"
33
#include "source_base/parallel_global.h"
4+
#include "source_base/parallel_reduce.h"
45
#include "source_hamilt/module_xc/xc_functional.h"
56

67

@@ -9,8 +10,8 @@ void Symmetry_rho::psymmg(std::complex<double>* rhog_part, const ModulePW::PW_Ba
910
//(1) get fftixy2is and do Allreduce
1011
int * fftixy2is = new int [rho_basis->fftnxy];
1112
rho_basis->getfftixy2is(fftixy2is); //current proc
13+
Parallel_Reduce::reduce_pool(fftixy2is, rho_basis->fftnxy);
1214
#ifdef __MPI
13-
MPI_Allreduce(MPI_IN_PLACE, fftixy2is, rho_basis->fftnxy, MPI_INT, MPI_SUM, POOL_WORLD);
1415
if(rho_basis->poolnproc>1)
1516
for (int i=0;i<rho_basis->fftnxy;++i)
1617
fftixy2is[i]+=rho_basis->poolnproc-1;

source/source_hsolver/diago_bpcg.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "source_base/global_function.h"
55
#include "source_base/kernels/math_kernel_op.h"
66
#include "source_base/parallel_comm.h" // different MPI worlds
7+
#include "source_base/parallel_reduce.h"
78
#include "source_hsolver/kernels/bpcg_kernel_op.h"
89
#include "para_linear_transform.h"
910

@@ -85,9 +86,7 @@ bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, const std::vecto
8586
not_conv = true;
8687
}
8788
}
88-
#ifdef __MPI
89-
MPI_Allreduce(MPI_IN_PLACE, &not_conv, 1, MPI_C_BOOL, MPI_LOR, BP_WORLD);
90-
#endif
89+
Parallel_Reduce::gather_or_bool_bp(not_conv);
9190
return not_conv;
9291
}
9392

source/source_io/output_log.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "source_base/global_variable.h"
77

88
#include "source_base/parallel_comm.h"
9+
#include "source_base/parallel_reduce.h"
910

1011
#ifdef __MPI
1112
#include <mpi.h>
@@ -154,7 +155,7 @@ void output_vacuum_level(const UnitCell* ucell,
154155
}
155156

156157
#ifdef __MPI
157-
MPI_Allreduce(MPI_IN_PLACE, ave, length, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
158+
Parallel_Reduce::reduce_pool(ave, length);
158159
#endif
159160

160161
int surface = nxyz / length;

0 commit comments

Comments
 (0)