Skip to content

Commit e308494

Browse files
committed
Refactor: complete MPI encapsulation and fix OMP pragmas for better compatibility
1 parent a2e85e4 commit e308494

File tree

4 files changed

+82
-5
lines changed

4 files changed

+82
-5
lines changed

source/source_basis/module_pw/test/depend_mock.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "mpi.h"
33
#endif
44
#include "depend_mock.h"
5+
#include <complex>
56

67
namespace GlobalV
78
{
@@ -12,13 +13,85 @@ MPI_Comm POOL_WORLD;
1213
// Stand-in definitions of the Parallel_Reduce functions for this test mock
// (depend_mock.cpp). Every function in this namespace has a body that consists
// solely of `return;`, so a call leaves all of its arguments unmodified.
// NOTE(review): presumably these signatures mirror the declarations in
// source_base/parallel_reduce.h so the tests can link without the real
// implementation -- confirm against that header.
namespace Parallel_Reduce
1314
{
// Primary templates: a scalar-by-reference form and a pointer-plus-count form
// for both reduce_all and reduce_pool.
1415
template<typename T> void reduce_all(T& object) { return; };
16+
template<typename T> void reduce_all(T* object, const int n) { return; };
1517
template<typename T> void reduce_pool(T& object) { return; };
18+
template<typename T> void reduce_pool(T* object, const int n) { return; };
1619

// Explicit specializations of reduce_all for scalar references.
20+
template<>
21+
void reduce_all<int>(int& object) { return; };
22+
template<>
23+
void reduce_all<long long>(long long& object) { return; };
1724
template<>
1825
void reduce_all<double>(double& object) { return; };
1926
template<>
20-
void reduce_pool<double>(double& object) { return; };
27+
void reduce_all<float>(float& object) { return; };
28+
template<>
29+
void reduce_all<std::complex<double>>(std::complex<double>& object) { return; };
30+
template<>
31+
void reduce_all<std::complex<float>>(std::complex<float>& object) { return; };
32+
// Explicit specializations of reduce_all for the (pointer, element count) form.
33+
template<>
34+
void reduce_all<int>(int* object, const int n) { return; };
35+
template<>
36+
void reduce_all<long long>(long long* object, const int n) { return; };
37+
template<>
38+
void reduce_all<double>(double* object, const int n) { return; };
39+
template<>
40+
void reduce_all<std::complex<double>>(std::complex<double>* object, const int n) { return; };
41+
template<>
42+
void reduce_all<std::complex<float>>(std::complex<float>* object, const int n) { return; };
43+
// Explicit specializations of reduce_pool for scalar references.
2144
template<>
2245
void reduce_pool<float>(float& object) { return; };
46+
template<>
47+
void reduce_pool<double>(double& object) { return; };
48+
template<>
49+
void reduce_pool<std::complex<double>>(std::complex<double>& object) { return; };
50+
// Explicit specializations of reduce_pool for the (pointer, element count) form.
51+
template<>
52+
void reduce_pool<int>(int* object, const int n) { return; };
53+
template<>
54+
void reduce_pool<double>(double* object, const int n) { return; };
55+
template<>
56+
void reduce_pool<std::complex<float>>(std::complex<float>* object, const int n) { return; };
57+
template<>
58+
void reduce_pool<std::complex<double>>(std::complex<double>* object, const int n) { return; };
59+
// Non-template bool stub; `object` is not touched.
60+
void reduce_or_all(bool& object) { return; };
61+
// reduce_max_all: primary template plus double/float/int specializations, all no-ops.
62+
template <typename T>
63+
void reduce_max_all(T& object) { return; };
64+
template<> void reduce_max_all<double>(double& object) { return; };
65+
template<> void reduce_max_all<float>(float& object) { return; };
66+
template<> void reduce_max_all<int>(int& object) { return; };
67+
// reduce_min_all: primary template plus double/float/int specializations, all no-ops.
68+
template <typename T>
69+
void reduce_min_all(T& object) { return; };
70+
template<> void reduce_min_all<double>(double& object) { return; };
71+
template<> void reduce_min_all<float>(float& object) { return; };
72+
template<> void reduce_min_all<int>(int& object) { return; };
73+
// Remaining non-template stubs (pool / bp / bgroup / kp / allpool variants);
// each one returns immediately without reading or writing its arguments.
74+
void reduce_max_pool(int* object, const int n) { return; };
75+
void reduce_min_pool(double& object) { return; };
76+
77+
void reduce_or_bp(bool& object) { return; };
78+
79+
void reduce_double_bgroup(double& object) { return; };
80+
void reduce_double_bgroup(double* object, const int n) { return; };
81+
82+
void reduce_double_bp(double& object) { return; };
83+
void reduce_double_bp(double* object, const int n) { return; };
84+
85+
void reduce_double_kp(double* object, const int n) { return; };
86+
87+
void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object) { return; };
88+
void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n) { return; };
89+
// gather_* stubs: the output arguments (`v`, `all`) are left exactly as passed in.
90+
void gather_min_int_all(const int& nproc, int& v) { return; };
91+
void gather_max_double_all(const int& nproc, double& v) { return; };
92+
void gather_min_double_all(const int& nproc, double& v) { return; };
93+
void gather_max_double_pool(const int& nproc_in_pool, double& v) { return; };
94+
void gather_min_double_pool(const int& nproc_in_pool, double& v) { return; };
95+
void gather_int_all(int& v, int* all) { return; };
2396
}
2497
#endif

source/source_lcao/module_lr/utils/lr_util.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <algorithm>
55
#include "source_cell/unitcell.h"
66
#include "source_base/constants.h"
7+
#include "source_base/parallel_reduce.h"
8+
#include "source_base/parallel_device.h"
79
#include "source_hamilt/module_xc/xc_functional.h"
810
namespace LR_Util
911
{
@@ -172,7 +174,7 @@ namespace LR_Util
172174
}
173175

174176
//reduce to root
175-
MPI_Allreduce(MPI_IN_PLACE, fullmat, global_nrow * global_ncol, get_mpi_datatype(), MPI_SUM, pv.comm());
177+
Parallel_Common::reduce_dev<T, base_device::DEVICE_CPU>(fullmat, global_nrow * global_ncol, pv.comm());
176178
};
177179
#endif
178180

source/source_lcao/module_operator_lcao/op_exx_lcao.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "source_io/module_parameter/parameter.h"
77
#include "source_lcao/module_ri/RI_2D_Comm.h"
88
#include "source_pw/module_pwdft/global.h"
9+
#include "source_base/parallel_reduce.h"
910
#include "source_hamilt/module_xc/xc_functional.h"
1011
#include "source_io/restart_exx_csr.h"
1112

@@ -245,7 +246,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
245246
// Add MPI communication to synchronize all_exist across processes
246247
#ifdef __MPI
247248
// don't read in any files if one of the processes doesn't have it
248-
MPI_Allreduce(MPI_IN_PLACE, &all_exist, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
249+
Parallel_Reduce::reduce_min_all(all_exist);
249250
#endif
250251
if (all_exist)
251252
{
@@ -264,7 +265,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
264265
std::ifstream ifs(restart_HR_path_cereal, std::ios::binary);
265266
int all_exist_cereal = ifs ? 1 : 0;
266267
#ifdef __MPI
267-
MPI_Allreduce(MPI_IN_PLACE, &all_exist_cereal, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
268+
Parallel_Reduce::reduce_min_all(all_exist_cereal);
268269
#endif
269270
if (!all_exist_cereal)
270271
{

source/source_pw/module_pwdft/setup_pwwfc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "source_pw/module_pwdft/setup_pwwfc.h" // pw_wfc
22
#include "source_base/parallel_comm.h" // POOL_WORLD
3+
#include "source_base/parallel_reduce.h"
34
#include "source_io/print_info.h" // print information
45

56
void pw::teardown_pwwfc(ModulePW::PW_Basis_K* &pw_wfc)
@@ -56,7 +57,7 @@ void pw::setup_pwwfc(const Input_para& inp,
5657
#ifdef __MPI
5758
if (inp.pw_seed > 0)
5859
{
59-
MPI_Allreduce(MPI_IN_PLACE, &pw_wfc->ggecut, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
60+
Parallel_Reduce::reduce_max_all(pw_wfc->ggecut);
6061
}
6162
// qianrui added 2021-8-13 so that different kpar parameters give the same
6263
// results

0 commit comments

Comments
 (0)