Skip to content

Commit 6211f0d

Browse files
committed
Refactor local partitioning routines
1 parent e3957b9 commit 6211f0d

File tree

2 files changed

+111
-231
lines changed

2 files changed

+111
-231
lines changed

kdpart.h

Lines changed: 102 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,18 @@ struct PartTreeStorage {
293293
walk(std::forward<Pred>(p), std::forward<Func>(f), 0);
294294
}
295295

296+
297+
/* Depth-first traversal of the subtree of node i (including i itself)
298+
* calling function f on each node.
299+
* @param f bool function with a single parameter of type
300+
* const_node_access_type. Return value determines if "walk_post"
301+
* further descends into subtree.
302+
*/
303+
template <typename Func>
304+
void walk_post(Func &&f) {
305+
walk_post(std::forward<Func>(f), 0);
306+
}
307+
296308
/** Invalidates the information in the subtree rooted at "root".
297309
* Used for local repartitioning.
298310
*/
@@ -366,6 +378,22 @@ struct PartTreeStorage {
366378
}
367379
}
368380

381+
382+
/* Depth-first traversal of the subtree of node i (including i itself)
383+
* calling function f on each node.
384+
* @param f bool function with a single parameter of type
385+
* const_node_access_type. Return value determines if "walk_post"
386+
* further descends into subtree.
387+
*/
388+
template <typename Func>
389+
void walk_post(Func f, int i) {
390+
const bool b = f(node(i));
391+
if (b && inner[i]) {
392+
walk_post(f, 2 * i + 1);
393+
walk_post(f, 2 * i + 2);
394+
}
395+
}
396+
369397
/// Descend direction for find(Chooser)
370398
enum class Descend {
371399
Left, Right
@@ -826,20 +854,8 @@ struct mpi_helper<std::array<int, 3>>
826854
}
827855
};
828856

829-
}
830-
831-
/** Repartitions a kd tree inline by means of local communication of the weights
832-
* and local computation of the new splits.
833-
*
834-
* Repartitioning is done between 2 or 3 processes that share the same
835-
* level 1 or level 2 subtree.
836-
*
837-
* @returns Parameter "s"
838-
*
839-
* @see repart_parttree_par
840-
*/
841-
template <LinearizeFunc LinearizeFn = linearize>
842-
PartTreeStorage& repart_parttree_par_local(PartTreeStorage& s, MPI_Comm comm, const std::vector<double>& cellweights)
857+
template <LinearizeFunc LinearizeFn = linearize, typename RepartRootFunc, typename IsRepartRootPred>
858+
PartTreeStorage& _repart_parttree_par_local_impl(PartTreeStorage& s, MPI_Comm comm, const std::vector<double>& cellweights, RepartRootFunc &&get_subtree_repartition_root, IsRepartRootPred &&is_repartition_subtree_root)
843859
{
844860
int size;
845861
MPI_Comm_size(comm, &size);
@@ -854,35 +870,35 @@ PartTreeStorage& repart_parttree_par_local(PartTreeStorage& s, MPI_Comm comm, co
854870
int rank;
855871
MPI_Comm_rank(comm, &rank);
856872
auto my_leaf = s.node_of_rank(rank);
857-
auto my_limb_root = my_leaf.find_limbend_root();
873+
auto my_subtree_root = get_subtree_repartition_root(my_leaf);
858874

859875
/* Get all ranks that have subdomains participating in this limb end */
860-
std::vector<int> limb_neighbors;
861-
limb_neighbors.reserve(3);
862-
for (int r = my_limb_root.pstart(); r < my_limb_root.pend() + 1; ++r) {
863-
limb_neighbors.push_back(r);
876+
std::vector<int> subtree_neighbors;
877+
subtree_neighbors.reserve(3);
878+
for (int r = my_subtree_root.pstart(); r < my_subtree_root.pend() + 1; ++r) {
879+
subtree_neighbors.push_back(r);
864880
}
865881

866-
/* Distribute weights to "limb_neighbors" */
882+
/* Distribute weights to "subtree_neighbors" */
867883

868884
// A group of processes participating in a limb is identified by the smallest
869885
// rank in this limb end.
870-
const int limb_group = *std::min_element(limb_neighbors.begin(), limb_neighbors.end());
886+
const int subtree_group_id = my_subtree_root.pstart();
871887
MPI_Comm neighcomm;
872-
MPI_Comm_split(comm, limb_group, rank, &neighcomm);
888+
MPI_Comm_split(comm, subtree_group_id, rank, &neighcomm);
873889
util::GlobalVector<double> neighbor_load(neighcomm, cellweights);
874890

875891
// Setup a mapping from ranks relative to "comm" to ranks relative to "neighcomm"
876892
// Since we chose the rank relative to "comm" as key in MPI_Comm_split,
877893
// we could also do this manually:
878-
// rank_to_neighrank[min(limb_neighbors)] = 0
879-
// rank_to_neighrank[max(limb_neighbors)] = limb_neighbors.size() - 1;
894+
// rank_to_neighrank[min(subtree_neighbors)] = 0
895+
// rank_to_neighrank[max(subtree_neighbors)] = subtree_neighbors.size() - 1;
880896
// And if it is a limbend of size 3 assign rank "1" to the middle element.
881897
// But this is less error prone if we ever choose to change the Comm_split.
882-
const std::map<int, int> rank_to_neighrank = map_ranks_to_comm(comm, limb_neighbors, neighcomm);
898+
const std::map<int, int> rank_to_neighrank = map_ranks_to_comm(comm, subtree_neighbors, neighcomm);
883899

884900
using cell_type = std::array<int, 3>;
885-
auto neighbor_load_func = [&old_part, &neighbor_load, &rank_to_neighrank, &limb_neighbors](const cell_type& c){
901+
auto neighbor_load_func = [&old_part, &neighbor_load, &rank_to_neighrank, &subtree_neighbors](const cell_type& c){
886902
auto n = old_part.node_of_cell(c);
887903

888904
// Transform c to process ("rank") local coordinates
@@ -895,7 +911,7 @@ PartTreeStorage& repart_parttree_par_local(PartTreeStorage& s, MPI_Comm comm, co
895911
const auto i = LinearizeFn(loc_c, loc_box);
896912

897913
// Map rank (relative to "comm") to "neighcomm".
898-
assert(std::find(limb_neighbors.begin(), limb_neighbors.end(), n.rank()) != limb_neighbors.end());
914+
assert(std::find(subtree_neighbors.begin(), subtree_neighbors.end(), n.rank()) != subtree_neighbors.end());
899915
const auto rank = rank_to_neighrank.at(n.rank());
900916

901917
assert(neighbor_load.size(rank) > i);
@@ -908,146 +924,13 @@ PartTreeStorage& repart_parttree_par_local(PartTreeStorage& s, MPI_Comm comm, co
908924
return quality_splitting(codimload(split_dir, lu, ro), nproc, nproc_left);
909925
};
910926

927+
const bool is_responsible_process = rank == subtree_group_id;
911928
/* Re-partition limb ends
912929
* The process with smallest rank number is responsible for re-creating
913930
* this limb
914931
*/
915-
if (rank == my_limb_root.pstart()) {
916-
auto nproc1 = my_limb_root.child1().nproc();
917-
auto nproc2 = my_limb_root.child2().nproc();
918-
919-
// Passing "nproc1" as last parameter ensures that "nproc1" and
920-
// "nproc2" do not change.
921-
impl::split_node(my_limb_root, splitfunc, nproc1);
922-
// Split_node avoids setting "inner", so do it manually
923-
// This is also the reason why we use ".nproc()" to distinguish between
924-
// nodes and not ".is_inner()"!
925-
auto child1 = my_limb_root.child1(), child2 = my_limb_root.child2();
926-
child1.inner() = 0;
927-
child2.inner() = 0;
928-
929-
assert(nproc1 == child1.nproc());
930-
assert(nproc2 == child2.nproc());
931-
932-
// In case of a 3-leaf limb split a child
933-
if (child1.nproc() > 1) {
934-
impl::split_node(child1, splitfunc);
935-
assert(child2.nproc() == 1);
936-
} else if (child2.nproc() > 1) {
937-
impl::split_node(child2, splitfunc);
938-
assert(child1.nproc() == 1);
939-
} else {
940-
assert(limb_neighbors.size() == 2);
941-
}
942-
}
943-
944-
/* Invalidate all limb ends
945-
* Limb ends are invalidated on all processes not participating as well as on
946-
* all participating processes BUT the one that performed the new split
947-
*/
948-
s.walk([my_limb_root, &s, rank](auto node){
949-
if (node.is_limbend() && node != my_limb_root && rank != my_limb_root.pstart()) {
950-
s.invalidate_subtree(node);
951-
// Stop "walk" from descending into nothing
952-
node.inner() = 0;
953-
}
954-
});
955-
956-
/* Re-distribute all changes
957-
* Use an Allreduce operation with MPI_MAX as operation.
958-
* This works because we set all fields of all limb ends to "0" above.
959-
*/
960-
std::vector<MPI_Request> reqs{};
961-
reqs.reserve(8);
962-
s.apply_to_data_vectors([comm, &reqs](auto &vec){
963-
using value_type = typename std::remove_reference<decltype(vec)>::type::value_type;
964-
reqs.push_back(mpi_helper<value_type>::iallreduce(vec, MPI_MAX, comm));
965-
});
966-
MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
967-
968-
MPI_Comm_free(&neighcomm);
969-
return s;
970-
}
971-
972-
template <LinearizeFunc LinearizeFn = linearize>
973-
PartTreeStorage& repart_parttree_par_local_top(PartTreeStorage& s, MPI_Comm comm, const std::vector<double>& cellweights, int depth)
974-
{
975-
int size;
976-
MPI_Comm_size(comm, &size);
977-
978-
// Nothing to do.
979-
if (size == 1)
980-
return s;
981-
982-
PartTreeStorage old_part = s; // Tree corresponding to "cellweights".
983-
984-
/* Find own limb end */
985-
int rank;
986-
MPI_Comm_rank(comm, &rank);
987-
auto my_leaf = s.node_of_rank(rank);
988-
auto my_limb_root = my_leaf.find_root_of_subtree(depth);
989-
990-
/* Get all ranks that have subdomains participating in this limb end */
991-
std::vector<int> limb_neighbors;
992-
limb_neighbors.reserve(3);
993-
for (int r = my_limb_root.pstart(); r < my_limb_root.pend() + 1; ++r) {
994-
limb_neighbors.push_back(r);
995-
}
996-
997-
/* Distribute weights to "limb_neighbors" */
998-
999-
// A group of processes participating in a limb is identified by the smallest
1000-
// rank in this limb end.
1001-
const int limb_group = *std::min_element(limb_neighbors.begin(), limb_neighbors.end());
1002-
MPI_Comm neighcomm;
1003-
MPI_Comm_split(comm, limb_group, rank, &neighcomm);
1004-
util::GlobalVector<double> neighbor_load(neighcomm, cellweights);
1005-
1006-
// Setup a mapping from ranks relative to "comm" to ranks relative to "neighcomm"
1007-
// Since we chose the rank relative to "comm" as key in MPI_Comm_split,
1008-
// we could also do this manually:
1009-
// rank_to_neighrank[min(limb_neighbors)] = 0
1010-
// rank_to_neighrank[max(limb_neighbors)] = limb_neighbors.size() - 1;
1011-
// And if it is a limbend of size 3 assign rank "1" to the middle element.
1012-
// But this is less error prone if we ever choose to change the Comm_split.
1013-
const std::map<int, int> rank_to_neighrank = map_ranks_to_comm(comm, limb_neighbors, neighcomm);
1014-
1015-
using cell_type = std::array<int, 3>;
1016-
auto neighbor_load_func = [&old_part, &neighbor_load, &rank_to_neighrank, &limb_neighbors](const cell_type& c){
1017-
auto n = old_part.node_of_cell(c);
1018-
1019-
// Transform c to process ("rank") local coordinates
1020-
cell_type loc_c, loc_box;
1021-
for (auto i = 0; i < 3; ++i) {
1022-
loc_c[i] = c[i] - n.lu()[i];
1023-
loc_box[i] = n.ro()[i] - n.lu()[i];
1024-
}
1025-
1026-
const auto i = LinearizeFn(loc_c, loc_box);
1027-
1028-
// Map rank (relative to "comm") to "neighcomm".
1029-
assert(std::find(limb_neighbors.begin(), limb_neighbors.end(), n.rank()) != limb_neighbors.end());
1030-
const auto rank = rank_to_neighrank.at(n.rank());
1031-
1032-
assert(neighbor_load.size(rank) > i);
1033-
return neighbor_load(rank, i);
1034-
};
1035-
1036-
auto codimload = util::CodimSum<decltype(neighbor_load_func)>(neighbor_load_func);
1037-
1038-
auto splitfunc = [&codimload](int split_dir, std::array<int, 3> lu, std::array<int, 3> ro, int nproc, int nproc_left) {
1039-
return quality_splitting(codimload(split_dir, lu, ro), nproc, nproc_left);
1040-
};
1041-
1042932
int new_max_depth = 0;
1043-
/* Re-partition limb ends
1044-
* The process with smallest rank number is responsible for re-creating
1045-
* this limb
1046-
*/
1047-
if (rank == my_limb_root.pstart() && limb_neighbors.size() > 1) {
1048-
std::cout << "pre re-split:" << std::endl;
1049-
std::cout << "child1.nproc " << my_limb_root.child1().nproc() << std::endl;
1050-
std::cout << "child2.nproc " << my_limb_root.child2().nproc() << std::endl;
933+
if (is_responsible_process && subtree_neighbors.size() > 1) {
1051934
s.walk_subtree([&s, &splitfunc, &new_max_depth](auto node) {
1052935
// Need to split the node further?
1053936
if (node.nproc() > 1) {
@@ -1059,23 +942,26 @@ PartTreeStorage& repart_parttree_par_local_top(PartTreeStorage& s, MPI_Comm comm
1059942
} else {
1060943
new_max_depth = std::max(new_max_depth, node.depth());
1061944
}
1062-
}, my_limb_root);
1063-
std::cout << "after re-split:" << std::endl;
1064-
std::cout << "child1.nproc " << my_limb_root.child1().nproc() << std::endl;
1065-
std::cout << "child2.nproc " << my_limb_root.child2().nproc() << std::endl;
945+
}, my_subtree_root);
1066946
}
1067947

1068948
/* Invalidate all limb ends
1069949
* Limb ends are invalidated on all processes not participating as well as on
1070950
* all participating processes BUT the one that performed the new split
1071951
*/
1072-
s.walk([my_limb_root, &s, rank, depth](auto node){
1073-
if (node.depth() == depth) {
1074-
if (node != my_limb_root || rank != my_limb_root.pstart()) {
952+
s.walk_post([my_subtree_root, &s, is_repartition_subtree_root, is_responsible_process](auto node){
953+
if (is_repartition_subtree_root(node)) {
954+
if (node != my_subtree_root || !is_responsible_process) {
1075955
s.invalidate_subtree(node);
1076956
// Stop "walk" from descending into nothing
1077957
node.inner() = 0;
1078958
}
959+
// Don't descend any further. For the limb end version of
960+
// repartitioning, is_repartition_subtree_root can also evaluate
961+
// to true for subtrees of "node".
962+
return false;
963+
} else {
964+
return true;
1079965
}
1080966
});
1081967

@@ -1099,6 +985,53 @@ PartTreeStorage& repart_parttree_par_local_top(PartTreeStorage& s, MPI_Comm comm
1099985
return s;
1100986
}
1101987

988+
}
989+
990+
/** Repartitions a kd tree inline by means of local communication of the weights
991+
* and local computation of the new splits.
992+
*
993+
* Repartitioning is done individually in each subtree rooted at the nodes
994+
* of depth "depth" in the tree.
995+
* In a regular binary tree 2^depth subgroups will be formed that repartition
996+
* individually. This makes s.root().nproc() / (1<<depth) many processes per
997+
* subgroup that will, internally, communicate their weights.
998+
*
999+
* @param depth Depth of the tree on which partitioning will be performed in subgroups
1000+
* @returns Parameter "s"
1001+
*
1002+
* @see repart_parttree_par
1003+
*/
1004+
template <LinearizeFunc LinearizeFn = linearize>
1005+
PartTreeStorage& repart_parttree_par_local_top(PartTreeStorage& s, MPI_Comm comm, const std::vector<double>& cellweights, int depth)
1006+
{
1007+
return _repart_parttree_par_local_impl(s, comm, cellweights, [depth](auto leaf){
1008+
return leaf.find_root_of_subtree(depth);
1009+
}, [depth](auto node) {
1010+
return node.depth() == depth;
1011+
});
1012+
}
1013+
1014+
1015+
/** Repartitions a kd tree inline by means of local communication of the weights
1016+
* and local computation of the new splits.
1017+
*
1018+
* Repartitioning is done between 2 or 3 processes that share the same
1019+
* level 1 or level 2 subtree.
1020+
*
1021+
* @returns Parameter "s"
1022+
*
1023+
* @see repart_parttree_par
1024+
*/
1025+
template <LinearizeFunc LinearizeFn = linearize>
1026+
PartTreeStorage& repart_parttree_par_local(PartTreeStorage& s, MPI_Comm comm, const std::vector<double>& cellweights)
1027+
{
1028+
return _repart_parttree_par_local_impl(s, comm, cellweights, [](auto leaf){
1029+
return leaf.find_limbend_root();
1030+
}, [](auto node){
1031+
return node.is_limbend();
1032+
});
1033+
}
1034+
11021035

11031036
/** Returns a tree with "size" subdomain.
11041037
*

0 commit comments

Comments
 (0)