Skip to content

Commit 6157da8

Browse files
committed
Refactor: encapsulate MPI_Allreduce(MPI_IN_PLACE) calls with Parallel_Reduce in source/
1 parent f1290e9 commit 6157da8

File tree

6 files changed

+22
-17
lines changed

6 files changed

+22
-17
lines changed

source/source_io/output_log.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include "source_base/global_variable.h"
77

88
#include "source_base/parallel_comm.h"
9+
#include "source_base/parallel_reduce.h"
910

1011
#ifdef __MPI
1112
#include <mpi.h>
@@ -154,7 +155,7 @@ void output_vacuum_level(const UnitCell* ucell,
154155
}
155156

156157
#ifdef __MPI
157-
MPI_Allreduce(MPI_IN_PLACE, ave, length, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
158+
Parallel_Reduce::reduce_pool(ave, length);
158159
#endif
159160

160161
int surface = nxyz / length;

source/source_pw/module_pwdft/VNL_in_pw.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include "source_base/math_ylmreal.h"
1111
#include "source_base/memory.h"
1212
#include "source_base/module_device/device.h"
13+
#include "source_base/parallel_reduce.h"
1314
#include "source_base/timer.h"
1415
#include "source_pw/module_pwdft/global.h"
1516
#include "source_pw/module_pwdft/kernels/vnl_op.h"
@@ -683,8 +684,8 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_
683684
}
684685

685686
#ifdef __MPI
686-
MPI_Allreduce(MPI_IN_PLACE, this->qq_nt.ptr, this->qq_nt.getSize(), MPI_DOUBLE, MPI_SUM, POOL_WORLD);
687-
MPI_Allreduce(MPI_IN_PLACE, this->qq_so.ptr, this->qq_so.getSize(), MPI_DOUBLE_COMPLEX, MPI_SUM, POOL_WORLD);
687+
Parallel_Reduce::reduce_pool(this->qq_nt.ptr, this->qq_nt.getSize());
688+
Parallel_Reduce::reduce_pool(this->qq_so.ptr, this->qq_so.getSize());
688689
#endif
689690

690691
// set the atomic specific qq_at matrices
@@ -1510,7 +1511,7 @@ void pseudopot_cell_vnl::newq(const ModuleBase::matrix& veff, const ModulePW::PW
15101511
}
15111512

15121513
#ifdef __MPI
1513-
MPI_Allreduce(MPI_IN_PLACE, deeq.ptr, deeq.getSize(), MPI_DOUBLE, MPI_SUM, POOL_WORLD);
1514+
Parallel_Reduce::reduce_pool(deeq.ptr, deeq.getSize());
15141515
#endif
15151516

15161517
delete[] qnorm;

source/source_pw/module_pwdft/elecond.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "source_base/global_variable.h"
55
#include "source_base/kernels/math_kernel_op.h"
66
#include "source_base/parallel_device.h"
7+
#include "source_base/parallel_reduce.h"
78
#include "source_estate/occupy.h"
89
#include "source_io/binstream.h"
910
#include "source_io/module_parameter/parameter.h"
@@ -93,9 +94,9 @@ void EleCond<FPTYPE, Device>::KG(const int& smear_type,
9394
jjresponse_ks(ik, nt, dt, decut, wg, velop, ct11.data(), ct12.data(), ct22.data());
9495
}
9596
#ifdef __MPI
96-
MPI_Allreduce(MPI_IN_PLACE, ct11.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
97-
MPI_Allreduce(MPI_IN_PLACE, ct12.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
98-
MPI_Allreduce(MPI_IN_PLACE, ct22.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
97+
Parallel_Reduce::reduce_all(ct11.data(), nt);
98+
Parallel_Reduce::reduce_all(ct12.data(), nt);
99+
Parallel_Reduce::reduce_all(ct22.data(), nt);
99100
#endif
100101
//------------------------------------------------------------------
101102
// Output

source/source_pw/module_stodft/sto_dos.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "sto_dos.h"
22

3+
#include "source_base/parallel_reduce.h"
34
#include "source_base/timer.h"
45
#include "source_base/tool_title.h"
56
#include "source_io/module_parameter/parameter.h"
@@ -235,8 +236,8 @@ void Sto_DOS<FPTYPE, Device>::caldos(const double sigmain, const double de, cons
235236
}
236237
#ifdef __MPI
237238
MPI_Allreduce(MPI_IN_PLACE, ks_dos.data(), ndos, MPI_DOUBLE, MPI_SUM, INT_BGROUP);
238-
MPI_Allreduce(MPI_IN_PLACE, sto_dos.data(), ndos, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
239-
MPI_Allreduce(MPI_IN_PLACE, error.data(), ndos, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
239+
Parallel_Reduce::reduce_all(sto_dos.data(), ndos);
240+
Parallel_Reduce::reduce_all(error.data(), ndos);
240241
#endif
241242
if (GlobalV::MY_RANK == 0)
242243
{

source/source_pw/module_stodft/sto_elecond.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "source_base/memory.h"
66
#include "source_base/module_container/ATen/tensor.h"
77
#include "source_base/parallel_device.h"
8+
#include "source_base/parallel_reduce.h"
89
#include "source_base/timer.h"
910
#include "source_base/vector3.h"
1011
#include "source_io/module_parameter/parameter.h"
@@ -1059,9 +1060,9 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10591060
} // ik loop
10601061
ModuleBase::timer::tick("Sto_EleCond", "kloop");
10611062
#ifdef __MPI
1062-
MPI_Allreduce(MPI_IN_PLACE, ct11.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
1063-
MPI_Allreduce(MPI_IN_PLACE, ct12.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
1064-
MPI_Allreduce(MPI_IN_PLACE, ct22.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
1063+
Parallel_Reduce::reduce_all(ct11.data(), nt);
1064+
Parallel_Reduce::reduce_all(ct12.data(), nt);
1065+
Parallel_Reduce::reduce_all(ct22.data(), nt);
10651066
#endif
10661067

10671068
//------------------------------------------------------------------

source/source_pw/module_stodft/sto_iter.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ void Stochastic_Iter<T, Device>::check_precision(const double ref, const double
249249
}
250250

251251
#ifdef __MPI
252-
MPI_Allreduce(MPI_IN_PLACE, &error, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
252+
Parallel_Reduce::reduce_all(error);
253253
#endif
254254
double relative_error = std::abs(error / ref);
255255
GlobalV::ofs_running << info << "Relative Chebyshev Precision: " << relative_error * 1e9 << "E-09" << std::endl;
@@ -473,7 +473,7 @@ double Stochastic_Iter<T, Device>::calne(elecstate::ElecState* pes)
473473
{
474474
MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD);
475475
}
476-
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
476+
Parallel_Reduce::reduce_all(sto_ne);
477477
#endif
478478

479479
totne = KS_ne + sto_ne;
@@ -540,7 +540,7 @@ void Stochastic_Iter<T, Device>::sum_stoeband(Stochastic_WF<T, Device>& stowf,
540540
{
541541
MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD);
542542
}
543-
MPI_Allreduce(MPI_IN_PLACE, &stodemet, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
543+
Parallel_Reduce::reduce_all(stodemet);
544544
#endif
545545
pes->f_en.demet += stodemet;
546546
this->check_precision(pes->f_en.demet, 1e-4, "TS");
@@ -581,7 +581,7 @@ void Stochastic_Iter<T, Device>::sum_stoeband(Stochastic_WF<T, Device>& stowf,
581581
}
582582
}
583583
#ifdef __MPI
584-
MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
584+
Parallel_Reduce::reduce_all(sto_eband);
585585
#endif
586586
pes->f_en.eband += sto_eband;
587587
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoeband");
@@ -695,7 +695,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell,
695695
sto_ne *= ucell.omega / wfc_basis->nxyz;
696696

697697
#ifdef __MPI
698-
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
698+
Parallel_Reduce::reduce_pool(sto_ne);
699699
#endif
700700
double factor = targetne / (KS_ne + sto_ne);
701701
if (std::abs(factor - 1) > 1e-10)

0 commit comments

Comments
 (0)