@@ -21,13 +21,14 @@ namespace RI
2121namespace Split_Processes
2222{
2323 // comm_color
24- static std::tuple<MPI_Comm,std::size_t > split (
25- const MPI_Comm &mpi_comm,
24+ static std::tuple<MPI_Wrapper::mpi_comm,std::size_t >
25+ split (
26+ const MPI_Comm &mc,
2627 const std::size_t &group_size)
2728 {
2829 assert (group_size>0 );
29- const std::size_t rank_mine = static_cast <std::size_t >(MPI_Wrapper::mpi_get_rank (mpi_comm ));
30- const std::size_t rank_size = static_cast <std::size_t >(MPI_Wrapper::mpi_get_size (mpi_comm ));
30+ const std::size_t rank_mine = static_cast <std::size_t >(MPI_Wrapper::mpi_get_rank (mc ));
31+ const std::size_t rank_size = static_cast <std::size_t >(MPI_Wrapper::mpi_get_size (mc ));
3132 assert (rank_size>=group_size);
3233
3334 std::vector<std::size_t > num (group_size); // sum(num) = rank_size
@@ -50,19 +51,22 @@ namespace Split_Processes
5051 throw std::range_error (std::string (__FILE__)+" line " +std::to_string (__LINE__));
5152 }();
5253
53- MPI_Comm mpi_comm_split;
54- MPI_CHECK ( MPI_Comm_split ( mpi_comm, static_cast <int >(color_group), static_cast <int >(rank_mine), &mpi_comm_split ) );
54+ MPI_Wrapper::mpi_comm mc_split;
55+ MPI_CHECK ( MPI_Comm_split (
56+ mc, static_cast <int >(color_group), static_cast <int >(rank_mine), &mc_split () ) );
57+ mc_split.flag_allocate = true ;
5558
56- return std::make_tuple (mpi_comm_split , color_group);
59+ return std::forward_as_tuple ( std::move (mc_split) , color_group);
5760 }
5861
5962 // comm_color_size
60- static std::tuple<MPI_Comm,std::size_t ,std::size_t > split_first (
61- const MPI_Comm &mpi_comm,
63+ static std::tuple<MPI_Wrapper::mpi_comm, std::size_t , std::size_t >
64+ split_first (
65+ const MPI_Comm &mc,
6266 const std::vector<std::size_t > &task_sizes)
6367 {
6468 assert (task_sizes.size ()>=1 );
65- const std::size_t rank_size = static_cast <std::size_t >(MPI_Wrapper::mpi_get_size (mpi_comm ));
69+ const std::size_t rank_size = static_cast <std::size_t >(MPI_Wrapper::mpi_get_size (mc ));
6670 const std::size_t task_product = std::accumulate (
6771 task_sizes.begin (), task_sizes.end (), std::size_t (1 ), std::multiplies<std::size_t >() ); // double for numerical range
6872 const double num_average =
@@ -73,19 +77,29 @@ namespace Split_Processes
7377 task_sizes[0 ] < num_average
7478 ? 1 // if task_sizes[0]<<task_sizes[1:], then group_size<0.5. Set group_size=1
7579 : static_cast <std::size_t >(std::round (task_sizes[0 ]/num_average));
76- const std::tuple<MPI_Comm,std::size_t > comm_color = split (mpi_comm, group_size);
77- return std::make_tuple (std::get<0 >(comm_color), std::get<1 >(comm_color), group_size);
80+ std::tuple<MPI_Wrapper::mpi_comm, std::size_t >
81+ comm_color = split (mc, group_size);
82+ return std::make_tuple (std::move (std::get<0 >(comm_color)), std::get<1 >(comm_color), group_size);
7883 }
7984
8085 // vector<comm_color_size>
81- static std::vector<std::tuple<MPI_Comm,std::size_t ,std::size_t >> split_all (
82- const MPI_Comm &mpi_comm,
86+ static std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t , std::size_t >>
87+ split_all (
88+ const MPI_Comm &mc,
8389 const std::vector<std::size_t > &task_sizes)
8490 {
85- std::vector<std::tuple<MPI_Comm,std::size_t ,std::size_t >> comm_color_sizes (task_sizes.size ()+1 );
86- comm_color_sizes[0 ] = std::make_tuple (mpi_comm, 0 , 1 );
91+ std::vector<std::tuple<MPI_Wrapper::mpi_comm, std::size_t ,std::size_t >>
92+ comm_color_sizes (task_sizes.size ()+1 );
93+ comm_color_sizes[0 ] = std::forward_as_tuple (
94+ MPI_Wrapper::mpi_comm (mc,false ),
95+ 0 ,
96+ 1 );
8797 for (std::size_t m=0 ; m<task_sizes.size (); ++m)
88- comm_color_sizes[m+1 ] = split_first (std::get<0 >(comm_color_sizes[m]), {task_sizes.begin ()+m, task_sizes.end ()});
98+ {
99+ comm_color_sizes[m+1 ] = split_first (
100+ std::get<0 >(comm_color_sizes[m])(),
101+ {task_sizes.begin ()+m, task_sizes.end ()});
102+ }
89103 return comm_color_sizes;
90104 }
91105}
0 commit comments