Skip to content

Commit f3bf0d7

Browse files
Marsella8Pietro Max Marsellalockshaw
authored
Balanced sp decomposition binary tree (#1604)
* added balanced_binary_sp_tree utilities * fmt * fmt * import fix * fmt * fix * PR fixes * Format --------- Co-authored-by: Pietro Max Marsella <[email protected]> Co-authored-by: Colin Unger <[email protected]>
1 parent 97aec04 commit f3bf0d7

File tree

16 files changed

+539
-10
lines changed

16 files changed

+539
-10
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H
2+
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H
3+
4+
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h"
5+
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"
6+
#include <optional>
7+
8+
namespace FlexFlow {
9+
10+
std::optional<PCGBinarySPDecomposition>
11+
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &);
12+
13+
} // namespace FlexFlow
14+
15+
#endif

lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ std::unordered_set<BinaryTreePath>
3636
find_paths_to_leaf(PCGBinarySPDecomposition const &,
3737
parallel_layer_guid_t const &);
3838

39+
PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
40+
BinarySPDecompositionTree const &spd_tree);
41+
3942
std::unordered_map<BinaryTreePath, parallel_layer_guid_t>
4043
pcg_sp_tree_get_path_to_leaf_map(PCGBinarySPDecomposition const &);
4144

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h"
2+
#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h"
3+
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
4+
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h"
5+
#include <optional>
6+
7+
namespace FlexFlow {
8+
9+
std::optional<PCGBinarySPDecomposition>
10+
get_pcg_balanced_binary_sp_decomposition(
11+
ParallelComputationGraph const &pcg) {
12+
13+
std::optional<SeriesParallelDecomposition> spd =
14+
get_pcg_series_parallel_decomposition(pcg);
15+
16+
if (!spd.has_value()) {
17+
return std::nullopt;
18+
}
19+
20+
return pcg_binary_sp_decomposition_from_binary_sp_tree(
21+
balanced_binary_sp_tree_from_nary(spd.value()));
22+
}
23+
24+
} // namespace FlexFlow

lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
2+
#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h"
23
#include "compiler/series_parallel/pcg/pcg_binary_series_split.h"
34
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h"
45
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
@@ -68,10 +69,7 @@ BinarySPDecompositionTree
6869
},
6970
[](PCGBinaryParallelSplit const &parallel) -> BinarySPDecompositionTree {
7071
return BinarySPDecompositionTree{
71-
BinaryParallelSplit{
72-
binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()),
73-
binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()),
74-
},
72+
binary_parallel_split_from_pcg_parallel_split(parallel),
7573
};
7674
},
7775
[](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree {
@@ -82,9 +80,35 @@ BinarySPDecompositionTree
8280
});
8381
}
8482

85-
std::optional<PCGBinarySPDecomposition>
86-
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) {
87-
NOT_IMPLEMENTED();
83+
PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
84+
BinarySPDecompositionTree const &spd_tree) {
85+
return spd_tree.visit<PCGBinarySPDecomposition>(overload{
86+
[](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition {
87+
return PCGBinarySPDecomposition{
88+
PCGBinarySeriesSplit{
89+
pcg_binary_sp_decomposition_from_binary_sp_tree(
90+
series.get_left_child()),
91+
pcg_binary_sp_decomposition_from_binary_sp_tree(
92+
series.get_right_child()),
93+
},
94+
};
95+
},
96+
[](BinaryParallelSplit const &parallel) -> PCGBinarySPDecomposition {
97+
return PCGBinarySPDecomposition{
98+
PCGBinaryParallelSplit{
99+
pcg_binary_sp_decomposition_from_binary_sp_tree(
100+
parallel.get_left_child()),
101+
pcg_binary_sp_decomposition_from_binary_sp_tree(
102+
parallel.get_right_child()),
103+
},
104+
};
105+
},
106+
[](Node const &node) -> PCGBinarySPDecomposition {
107+
return PCGBinarySPDecomposition{
108+
parallel_layer_guid_t{node},
109+
};
110+
},
111+
});
88112
}
89113

90114
std::unordered_multiset<parallel_layer_guid_t>
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
2+
#include "test/utils/doctest/fmt/unordered_multiset.h"
3+
#include "test/utils/rapidcheck.h"
4+
#include <doctest/doctest.h>
5+
6+
using namespace ::FlexFlow;
7+
8+
TEST_SUITE(FF_TEST_SUITE) {
9+
TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_tree") {
10+
Node n1 = Node{1};
11+
Node n2 = Node{2};
12+
Node n3 = Node{3};
13+
14+
auto make_binary_series_split = [](BinarySPDecompositionTree const &lhs,
15+
BinarySPDecompositionTree const &rhs) {
16+
return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}};
17+
};
18+
19+
auto make_binary_parallel_split = [](BinarySPDecompositionTree const &lhs,
20+
BinarySPDecompositionTree const &rhs) {
21+
return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}};
22+
};
23+
24+
auto make_binary_leaf = [](Node const &n) {
25+
return BinarySPDecompositionTree{n};
26+
};
27+
28+
auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs,
29+
PCGBinarySPDecomposition const &rhs) {
30+
return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}};
31+
};
32+
33+
auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs,
34+
PCGBinarySPDecomposition const &rhs) {
35+
return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}};
36+
};
37+
38+
auto make_pcg_leaf = [](Node const &n) {
39+
return PCGBinarySPDecomposition{parallel_layer_guid_t{n}};
40+
};
41+
42+
SUBCASE("single node") {
43+
BinarySPDecompositionTree input = make_binary_leaf(n1);
44+
45+
PCGBinarySPDecomposition result =
46+
pcg_binary_sp_decomposition_from_binary_sp_tree(input);
47+
48+
PCGBinarySPDecomposition expected = make_pcg_leaf(n1);
49+
50+
CHECK(result == expected);
51+
}
52+
53+
SUBCASE("series split") {
54+
BinarySPDecompositionTree input =
55+
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2));
56+
57+
PCGBinarySPDecomposition result =
58+
pcg_binary_sp_decomposition_from_binary_sp_tree(input);
59+
60+
PCGBinarySPDecomposition expected =
61+
make_pcg_series_split(make_pcg_leaf(n1), make_pcg_leaf(n2));
62+
63+
CHECK(result == expected);
64+
}
65+
66+
SUBCASE("parallel split") {
67+
BinarySPDecompositionTree input = make_binary_parallel_split(
68+
make_binary_leaf(n1), make_binary_leaf(n2));
69+
70+
PCGBinarySPDecomposition result =
71+
pcg_binary_sp_decomposition_from_binary_sp_tree(input);
72+
73+
PCGBinarySPDecomposition expected =
74+
make_pcg_parallel_split(make_pcg_leaf(n1), make_pcg_leaf(n2));
75+
76+
CHECK(result == expected);
77+
}
78+
79+
SUBCASE("bijectiveness") {
80+
BinarySPDecompositionTree original = make_binary_parallel_split(
81+
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)),
82+
make_binary_leaf(n3));
83+
84+
PCGBinarySPDecomposition pcg_tree =
85+
pcg_binary_sp_decomposition_from_binary_sp_tree(original);
86+
BinarySPDecompositionTree converted =
87+
binary_sp_tree_from_pcg_sp_tree(pcg_tree);
88+
89+
CHECK(original == converted);
90+
}
91+
}
92+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_TREE_HEIGHT_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_TREE_HEIGHT_H
3+
4+
#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h"
5+
#include "utils/full_binary_tree/visit.h"
6+
#include "utils/nonnegative_int/nonnegative_int.h"
7+
8+
namespace FlexFlow {
9+
10+
template <typename Tree, typename Parent, typename Leaf>
11+
nonnegative_int get_tree_height(
12+
Tree const &tree,
13+
FullBinaryTreeImplementation<Tree, Parent, Leaf> const &impl) {
14+
15+
auto visitor = FullBinaryTreeVisitor<nonnegative_int, Tree, Parent, Leaf>{
16+
[&](Parent const &parent) -> nonnegative_int {
17+
nonnegative_int left_height =
18+
get_tree_height(impl.get_left_child(parent), impl);
19+
nonnegative_int right_height =
20+
get_tree_height(impl.get_right_child(parent), impl);
21+
return std::max(left_height, right_height) + 1_n;
22+
},
23+
[](Leaf const &) -> nonnegative_int { return 0_n; },
24+
};
25+
26+
return visit(tree, impl, visitor);
27+
}
28+
29+
} // namespace FlexFlow
30+
31+
#endif
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H
3+
4+
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
5+
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"
6+
7+
namespace FlexFlow {
8+
9+
BinarySPDecompositionTree
10+
balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary);
11+
12+
} // namespace FlexFlow
13+
14+
#endif

lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
77
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h"
88
#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h"
9+
#include "utils/nonnegative_int/nonnegative_int.h"
910
#include <unordered_set>
1011

1112
namespace FlexFlow {
@@ -23,6 +24,8 @@ std::unordered_multiset<Node> get_leaves(BinarySPDecompositionTree const &);
2324

2425
SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &);
2526

27+
nonnegative_int get_tree_height(BinarySPDecompositionTree const &);
28+
2629
} // namespace FlexFlow
2730

2831
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_TREE_HEIGHT_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_TREE_HEIGHT_H
3+
4+
#include "utils/full_binary_tree/get_tree_height.h"
5+
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h"
6+
7+
namespace FlexFlow {
8+
9+
template <typename Tree, typename Series, typename Parallel, typename Leaf>
10+
nonnegative_int get_tree_height(
11+
Tree const &tree,
12+
GenericBinarySPDecompositionTreeImplementation<Tree,
13+
Series,
14+
Parallel,
15+
Leaf> const &impl) {
16+
17+
FullBinaryTreeImplementation<Tree, std::variant<Series, Parallel>, Leaf>
18+
full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl);
19+
20+
return get_tree_height(tree, full_binary_impl);
21+
}
22+
23+
} // namespace FlexFlow
24+
25+
#endif
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include "utils/full_binary_tree/get_tree_height.h"
2+
#include "utils/archetypes/value_type.h"
3+
4+
namespace FlexFlow {
5+
6+
using Tree = value_type<0>;
7+
using Parent = value_type<1>;
8+
using Leaf = value_type<2>;
9+
10+
template nonnegative_int
11+
get_tree_height(Tree const &,
12+
FullBinaryTreeImplementation<Tree, Parent, Leaf> const &);
13+
14+
} // namespace FlexFlow

0 commit comments

Comments
 (0)