
Commit 77062c5

1. add class MPI_Wrapper::mpi_comm
1 parent 7506cc0 commit 77062c5

8 files changed: +224 additions, -112 deletions


README.cn.md

Lines changed: 2 additions & 2 deletions
@@ -22,5 +22,5 @@ LibRI is a header-only C++ library for computing high-order methods under the RI form
 - MPI library, for inter-process data communication.
 - BLAS and LAPACK libraries, for accelerating tensor operations.
 > If the Math Kernel Library (MKL) is used for BLAS and LAPACK, it is recommended to define the macro `__MKL_RI` before including any LibRI header. Some functions in LibRI are then automatically replaced at compile time by their MKL counterparts.
-- cereal library, for data serialization and deserialization; header-only.
-- LibComm library, for inter-process data transfer; header-only.
+- [cereal](https://uscilab.github.io/cereal/) library, for data serialization and deserialization; header-only.
+- [LibComm](https://github.com/abacusmodeling/LibComm.git) library, for inter-process data transfer; header-only.
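
The `__MKL_RI` note above amounts to a one-line setup. A minimal sketch, assuming the LibRI headers are reachable as `<RI/...>` (the header chosen here is just one of the files touched by this commit); defining the macro on the compiler command line, e.g. `-D__MKL_RI`, works as well:

    // Sketch only: __MKL_RI must be defined before any LibRI header is included.
    #define __MKL_RI
    #include <RI/distribute/Split_Processes.h>   // illustrative LibRI header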

include/RI/distribute/Distribute_Equally.hpp

Lines changed: 4 additions & 4 deletions
@@ -28,14 +28,14 @@ namespace Distribute_Equally
     using TAC = std::pair<TA,std::array<Tcell,Ndim>>;
 
     const std::vector<std::size_t> task_sizes(num_index, atoms.size());
-    const std::vector<std::tuple<MPI_Comm,std::size_t,std::size_t>>
+    const std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t>>
         comm_color_sizes = Split_Processes::split_all(mpi_comm, task_sizes);
 
     std::pair<std::vector<TA>, std::vector<std::vector<TAC>>> atoms_split_list;
     atoms_split_list.second.resize(num_index-1);
 
     if(!flag_task_repeatable)
-        if(RI::MPI_Wrapper::mpi_get_rank(std::get<0>(comm_color_sizes.back())))
+        if(RI::MPI_Wrapper::mpi_get_rank(std::get<0>(comm_color_sizes.back())()))
             return atoms_split_list;
 
     atoms_split_list.first = Divide_Atoms::divide_atoms(
@@ -69,14 +69,14 @@ namespace Distribute_Equally
     const std::size_t task_size_period = atoms.size() * std::accumulate( period.begin(), period.end(), 1, std::multiplies<Tcell>() );
     std::vector<std::size_t> task_sizes(num_index, task_size_period);
    task_sizes[0] = atoms.size();
-    const std::vector<std::tuple<MPI_Comm,std::size_t,std::size_t>>
+    const std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t>>
         comm_color_sizes = Split_Processes::split_all(mpi_comm, task_sizes);
 
     std::pair<std::vector<TA>, std::vector<std::vector<TAC>>> atoms_split_list;
     atoms_split_list.second.resize(num_index-1);
 
     if(!flag_task_repeatable)
-        if(RI::MPI_Wrapper::mpi_get_rank(std::get<0>(comm_color_sizes.back())))
+        if(RI::MPI_Wrapper::mpi_get_rank(std::get<0>(comm_color_sizes.back())()))
             return atoms_split_list;
 
     atoms_split_list.first = Divide_Atoms::divide_atoms(
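
The only call-site change in this file is the extra `()` after `comm_color_sizes.back()`: `Split_Processes::split_all` now returns `MPI_Wrapper::mpi_comm` wrappers instead of raw `MPI_Comm`, and `operator()` hands back the underlying handle that `mpi_get_rank` still expects. A minimal sketch of the idiom (the communicator used here is illustrative):

    RI::MPI_Wrapper::mpi_comm mc(MPI_COMM_WORLD, false);    // borrow MPI_COMM_WORLD; the wrapper never frees it
    const int rank = RI::MPI_Wrapper::mpi_get_rank(mc());   // mc() returns the raw MPI_Comm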

include/RI/distribute/Split_Processes.h

Lines changed: 9 additions & 7 deletions
@@ -5,6 +5,8 @@
 
 #pragma once
 
+#include "../global/MPI_Wrapper.h"
+
 #include <mpi.h>
 #include <tuple>
 #include <vector>
@@ -16,20 +18,20 @@ namespace Split_Processes
 {
     // Split all processes into group_size groups, with the number of processes per group as equal as possible.
     // Returns {the mpi_comm of the group this process belongs to, the index of that group}.
-    static std::tuple<MPI_Comm,std::size_t> split(
-        const MPI_Comm &mpi_comm,
+    static std::tuple<MPI_Wrapper::mpi_comm, std::size_t> split(
+        const MPI_Comm &mc,
         const std::size_t &group_size);
 
     // The task counts are multi-dimensional; all processes are split into groups, with the number of tasks per group as equal as possible.
     // Returns the split along dimension 0: {the mpi_comm of the group this process belongs to, the index of that group, the total number of groups}.
-    static std::tuple<MPI_Comm,std::size_t,std::size_t> split_first(
-        const MPI_Comm &mpi_comm,
+    static std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t> split_first(
+        const MPI_Comm &mc,
         const std::vector<std::size_t> &task_sizes);
 
     // The task counts are multi-dimensional; all processes are split along every dimension, with the number of tasks per dimension as equal as possible.
-    // Returns the splits along all dimensions: element [0] is {mpi_comm,0,1}; element [i+1] is the split along dimension i.
-    static std::vector<std::tuple<MPI_Comm,std::size_t,std::size_t>> split_all(
-        const MPI_Comm &mpi_comm,
+    // Returns the splits along all dimensions: element [0] is {mc,0,1}; element [i+1] is the split along dimension i.
+    static std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t>> split_all(
+        const MPI_Comm &mc,
         const std::vector<std::size_t> &task_sizes);
 }
 
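
A rough usage sketch of these declarations, under the assumption that the headers of this commit are installed as `<RI/...>`; the task sizes and the program frame are illustrative, not taken from the repository:

    #include <mpi.h>
    #include <tuple>
    #include <vector>
    #include <RI/distribute/Split_Processes.hpp>   // assumed include path of the definitions below
    
    int main(int argc, char** argv)
    {
        MPI_Init(&argc, &argv);
        {
            // Two task dimensions with 8 tasks each; split_all splits the ranks dimension by dimension.
            const std::vector<std::size_t> task_sizes = {8, 8};
            const std::vector<std::tuple<RI::MPI_Wrapper::mpi_comm, std::size_t, std::size_t>>
                comm_color_sizes = RI::Split_Processes::split_all(MPI_COMM_WORLD, task_sizes);
            // comm_color_sizes[0] wraps MPI_COMM_WORLD itself as {mc,0,1};
            // comm_color_sizes[i+1] is the result of splitting along dimension i.
            const int rank_in_finest_group =
                RI::MPI_Wrapper::mpi_get_rank(std::get<0>(comm_color_sizes.back())());
            (void)rank_in_finest_group;
        }   // wrappers are destroyed (and owned communicators freed) before MPI_Finalize
        MPI_Finalize();
        return 0;
    }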

include/RI/distribute/Split_Processes.hpp

Lines changed: 31 additions & 17 deletions
@@ -21,13 +21,14 @@ namespace RI
 namespace Split_Processes
 {
     // comm_color
-    static std::tuple<MPI_Comm,std::size_t> split(
-        const MPI_Comm &mpi_comm,
+    static std::tuple<MPI_Wrapper::mpi_comm,std::size_t>
+    split(
+        const MPI_Comm &mc,
         const std::size_t &group_size)
     {
         assert(group_size>0);
-        const std::size_t rank_mine = static_cast<std::size_t>(MPI_Wrapper::mpi_get_rank(mpi_comm));
-        const std::size_t rank_size = static_cast<std::size_t>(MPI_Wrapper::mpi_get_size(mpi_comm));
+        const std::size_t rank_mine = static_cast<std::size_t>(MPI_Wrapper::mpi_get_rank(mc));
+        const std::size_t rank_size = static_cast<std::size_t>(MPI_Wrapper::mpi_get_size(mc));
         assert(rank_size>=group_size);
 
         std::vector<std::size_t> num(group_size);    // sum(num) = rank_size
@@ -50,19 +51,22 @@ namespace Split_Processes
             throw std::range_error(std::string(__FILE__)+" line "+std::to_string(__LINE__));
         }();
 
-        MPI_Comm mpi_comm_split;
-        MPI_CHECK( MPI_Comm_split( mpi_comm, static_cast<int>(color_group), static_cast<int>(rank_mine), &mpi_comm_split ) );
+        MPI_Wrapper::mpi_comm mc_split;
+        MPI_CHECK( MPI_Comm_split(
+            mc, static_cast<int>(color_group), static_cast<int>(rank_mine), &mc_split() ) );
+        mc_split.flag_allocate = true;
 
-        return std::make_tuple(mpi_comm_split, color_group);
+        return std::forward_as_tuple(std::move(mc_split), color_group);
     }
 
     // comm_color_size
-    static std::tuple<MPI_Comm,std::size_t,std::size_t> split_first(
-        const MPI_Comm &mpi_comm,
+    static std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t>
+    split_first(
+        const MPI_Comm &mc,
         const std::vector<std::size_t> &task_sizes)
     {
         assert(task_sizes.size()>=1);
-        const std::size_t rank_size = static_cast<std::size_t>(MPI_Wrapper::mpi_get_size(mpi_comm));
+        const std::size_t rank_size = static_cast<std::size_t>(MPI_Wrapper::mpi_get_size(mc));
         const std::size_t task_product = std::accumulate(
             task_sizes.begin(), task_sizes.end(), std::size_t(1), std::multiplies<std::size_t>() );    // double for numerical range
         const double num_average =
@@ -73,19 +77,29 @@ namespace Split_Processes
             task_sizes[0] < num_average
             ? 1    // if task_sizes[0]<<task_sizes[1:], then group_size<0.5. Set group_size=1
             : static_cast<std::size_t>(std::round(task_sizes[0]/num_average));
-        const std::tuple<MPI_Comm,std::size_t> comm_color = split(mpi_comm, group_size);
-        return std::make_tuple(std::get<0>(comm_color), std::get<1>(comm_color), group_size);
+        std::tuple<MPI_Wrapper::mpi_comm, std::size_t>
+            comm_color = split(mc, group_size);
+        return std::make_tuple(std::move(std::get<0>(comm_color)), std::get<1>(comm_color), group_size);
     }
 
     // vector<comm_color_size>
-    static std::vector<std::tuple<MPI_Comm,std::size_t,std::size_t>> split_all(
-        const MPI_Comm &mpi_comm,
+    static std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t, std::size_t>>
+    split_all(
+        const MPI_Comm &mc,
        const std::vector<std::size_t> &task_sizes)
     {
-        std::vector<std::tuple<MPI_Comm,std::size_t,std::size_t>> comm_color_sizes(task_sizes.size()+1);
-        comm_color_sizes[0] = std::make_tuple(mpi_comm, 0, 1);
+        std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t,std::size_t>>
+            comm_color_sizes(task_sizes.size()+1);
+        comm_color_sizes[0] = std::forward_as_tuple(
+            MPI_Wrapper::mpi_comm(mc,false),
+            0,
+            1);
         for(std::size_t m=0; m<task_sizes.size(); ++m)
-            comm_color_sizes[m+1] = split_first(std::get<0>(comm_color_sizes[m]), {task_sizes.begin()+m, task_sizes.end()});
+        {
+            comm_color_sizes[m+1] = split_first(
+                std::get<0>(comm_color_sizes[m])(),
+                {task_sizes.begin()+m, task_sizes.end()});
+        }
         return comm_color_sizes;
     }
 }
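
The ownership convention used above, in a compact illustrative sketch (not code from the commit): a communicator created by `MPI_Comm_split` is owned by the wrapper (`flag_allocate = true`, so its destructor calls `MPI_Comm_free`), whereas the caller's communicator stored in `comm_color_sizes[0]` is only borrowed (`flag_allocate = false`):

    // Owned: a freshly created communicator, released by the wrapper's destructor.
    RI::MPI_Wrapper::mpi_comm owned;
    MPI_Comm_split(MPI_COMM_WORLD, /*color=*/0, /*key=*/0, &owned());
    owned.flag_allocate = true;
    
    // Borrowed: wrap an existing communicator without taking ownership.
    RI::MPI_Wrapper::mpi_comm borrowed(MPI_COMM_WORLD, false);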
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+// ===================
+// Author: Peize Lin
+// date: 2023.06.08
+// ===================
+
+#pragma once
+
+#include <mpi.h>
+#include <stdexcept>
+#include <string>
+
+#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__));
+
+namespace RI
+{
+
+namespace MPI_Wrapper
+{
+    class mpi_comm
+    {
+    public:
+        MPI_Comm comm;
+        bool flag_allocate = false;    // flag_allocate=true is controlled by the user
+
+        mpi_comm() = default;
+        mpi_comm(const MPI_Comm &comm_in, const bool &flag_allocate_in): comm(comm_in), flag_allocate(flag_allocate_in) {}
+        mpi_comm(const mpi_comm &mc_in) = delete;
+        mpi_comm(mpi_comm &mc_in) = delete;
+        mpi_comm(mpi_comm &&mc_in)
+        {
+            this->free();
+            this->comm = mc_in.comm;
+            this->flag_allocate = mc_in.flag_allocate;
+            mc_in.flag_allocate = false;
+        }
+        mpi_comm &operator=(const mpi_comm &mc_in) = delete;
+        mpi_comm &operator=(mpi_comm &mc_in) = delete;
+        mpi_comm &operator=(mpi_comm &&mc_in)
+        {
+            this->free();
+            this->comm = mc_in.comm;
+            this->flag_allocate = mc_in.flag_allocate;
+            mc_in.flag_allocate = false;
+            return *this;
+        }
+
+        ~mpi_comm() { this->free(); }
+
+        MPI_Comm &operator()(){ return this->comm; }
+        const MPI_Comm &operator()()const{ return this->comm; }
+
+        void free()
+        {
+            if(this->flag_allocate)
+            {
+                MPI_CHECK( MPI_Comm_free( &this->comm ) );
+                this->flag_allocate = false;
+            }
+        }
+    };
+}
+
+}
+
+#undef MPI_CHECK
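
Since both copy operations are deleted, a `mpi_comm` can only be moved; the move transfers `flag_allocate`, so `MPI_Comm_free` runs at most once per communicator. A minimal sketch of this, assuming the header above is reachable through the LibRI include directory (the duplicated communicator is purely illustrative):

    #include <mpi.h>
    // plus the header above, which defines RI::MPI_Wrapper::mpi_comm (its path is not shown in this extract)
    
    // Returns a wrapper that owns a duplicate of MPI_COMM_WORLD.
    RI::MPI_Wrapper::mpi_comm make_world_dup()
    {
        RI::MPI_Wrapper::mpi_comm mc;
        MPI_Comm_dup(MPI_COMM_WORLD, &mc());   // mc() exposes the raw handle as an output argument
        mc.flag_allocate = true;               // mc now owns the duplicate
        return mc;                             // moved out; ownership travels with it
    }
    // Caller: RI::MPI_Wrapper::mpi_comm mc = make_world_dup();
    // When mc goes out of scope, its destructor calls MPI_Comm_free exactly once.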
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+// ===================
+// Author: Peize Lin
+// date: 2022.06.02
+// ===================
+
+#pragma once
+
+#include <mpi.h>
+#include <complex>
+#include <vector>
+#include <string>
+#include <stdexcept>
+
+#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__));
+
+namespace RI
+{
+
+namespace MPI_Wrapper
+{
+    inline int mpi_get_rank(const MPI_Comm &mpi_comm)
+    {
+        int rank_mine;
+        MPI_CHECK( MPI_Comm_rank (mpi_comm, &rank_mine) );
+        return rank_mine;
+    }
+
+    inline int mpi_get_size(const MPI_Comm &mpi_comm)
+    {
+        int rank_size;
+        MPI_CHECK( MPI_Comm_size (mpi_comm, &rank_size) );
+        return rank_size;
+    }
+
+    inline MPI_Datatype mpi_get_datatype(const char &v) { return MPI_CHAR; }
+    inline MPI_Datatype mpi_get_datatype(const short &v) { return MPI_SHORT; }
+    inline MPI_Datatype mpi_get_datatype(const int &v) { return MPI_INT; }
+    inline MPI_Datatype mpi_get_datatype(const long &v) { return MPI_LONG; }
+    inline MPI_Datatype mpi_get_datatype(const long long &v) { return MPI_LONG_LONG; }
+    inline MPI_Datatype mpi_get_datatype(const unsigned char &v) { return MPI_UNSIGNED_CHAR; }
+    inline MPI_Datatype mpi_get_datatype(const unsigned short &v) { return MPI_UNSIGNED_SHORT; }
+    inline MPI_Datatype mpi_get_datatype(const unsigned int &v) { return MPI_UNSIGNED; }
+    inline MPI_Datatype mpi_get_datatype(const unsigned long &v) { return MPI_UNSIGNED_LONG; }
+    inline MPI_Datatype mpi_get_datatype(const unsigned long long &v) { return MPI_UNSIGNED_LONG_LONG; }
+    inline MPI_Datatype mpi_get_datatype(const float &v) { return MPI_FLOAT; }
+    inline MPI_Datatype mpi_get_datatype(const double &v) { return MPI_DOUBLE; }
+    inline MPI_Datatype mpi_get_datatype(const long double &v) { return MPI_LONG_DOUBLE; }
+    inline MPI_Datatype mpi_get_datatype(const bool &v) { return MPI_CXX_BOOL; }
+    inline MPI_Datatype mpi_get_datatype(const std::complex<float> &v) { return MPI_CXX_FLOAT_COMPLEX; }
+    inline MPI_Datatype mpi_get_datatype(const std::complex<double> &v) { return MPI_CXX_DOUBLE_COMPLEX; }
+    inline MPI_Datatype mpi_get_datatype(const std::complex<long double> &v) { return MPI_CXX_LONG_DOUBLE_COMPLEX; }
+
+    //inline int mpi_get_count(const MPI_Status &status, const MPI_Datatype &datatype)
+    //{
+    //    int count;
+    //    MPI_CHECK( MPI_Get_count(&status, datatype, &count) );
+    //    return count;
+    //}
+
+    template<typename T>
+    inline void mpi_reduce(T &data, const MPI_Op &op, const int &root, const MPI_Comm &mpi_comm)
+    {
+        T data_out;
+        MPI_CHECK( MPI_Reduce(&data, &data_out, 1, mpi_get_datatype(data), op, root, mpi_comm) );
+        if(mpi_get_rank(mpi_comm)==root)
+            data = data_out;
+    }
+    template<typename T>
+    inline void mpi_allreduce(T &data, const MPI_Op &op, const MPI_Comm &mpi_comm)
+    {
+        T data_out;
+        MPI_CHECK( MPI_Allreduce(&data, &data_out, 1, mpi_get_datatype(data), op, mpi_comm) );
+        data = data_out;
+    }
+
+    template<typename T>
+    inline void mpi_reduce(T*const ptr, const int &count, const MPI_Op &op, const int &root, const MPI_Comm &mpi_comm)
+    {
+        std::vector<T> ptr_out(count);
+        MPI_CHECK( MPI_Reduce(ptr, ptr_out.data(), count, mpi_get_datatype(*ptr), op, root, mpi_comm) );
+        if(mpi_get_rank(mpi_comm)==root)
+            for(std::size_t i=0; i<count; ++i)
+                ptr[i] = ptr_out[i];
+    }
+    template<typename T>
+    inline void mpi_allreduce(T*const ptr, const int &count, const MPI_Op &op, const MPI_Comm &mpi_comm)
+    {
+        std::vector<T> ptr_out(count);
+        MPI_CHECK( MPI_Allreduce(ptr, ptr_out.data(), count, mpi_get_datatype(*ptr), op, mpi_comm) );
+        for(std::size_t i=0; i<count; ++i)
+            ptr[i] = ptr_out[i];
+    }
+}
+
+}
+
+#undef MPI_CHECK
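
A rough usage sketch of the helpers above, assuming the header is installed as `<RI/global/MPI_Wrapper.h>` (the relative path `../global/MPI_Wrapper.h` added to Split_Processes.h in this commit suggests that location):

    #include <mpi.h>
    #include <vector>
    #include <RI/global/MPI_Wrapper.h>   // assumed install path of the header above
    
    int main(int argc, char** argv)
    {
        MPI_Init(&argc, &argv);
    
        double local = 1.5 * RI::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD);
        RI::MPI_Wrapper::mpi_allreduce(local, MPI_SUM, MPI_COMM_WORLD);     // scalar overload: every rank gets the sum
    
        std::vector<int> counts(4, 1);
        RI::MPI_Wrapper::mpi_reduce(counts.data(), 4, MPI_SUM, /*root=*/0, MPI_COMM_WORLD);   // pointer overload: result lands on root only
    
        MPI_Finalize();
        return 0;
    }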
