Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H

#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"
#include <optional>

namespace FlexFlow {

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ std::unordered_set<BinaryTreePath>
find_paths_to_leaf(PCGBinarySPDecomposition const &,
parallel_layer_guid_t const &);

PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
BinarySPDecompositionTree const &spd_tree);

std::unordered_map<BinaryTreePath, parallel_layer_guid_t>
pcg_sp_tree_get_path_to_leaf_map(PCGBinarySPDecomposition const &);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h"
#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.h"
#include <optional>

namespace FlexFlow {

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(
ParallelComputationGraph const &pcg) {

std::optional<SeriesParallelDecomposition> spd =
get_pcg_series_parallel_decomposition(pcg);

if (!spd.has_value()) {
return std::nullopt;
}

return pcg_binary_sp_decomposition_from_binary_sp_tree(
balanced_binary_sp_tree_from_nary(spd.value()));
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
Expand Down Expand Up @@ -68,10 +69,7 @@ BinarySPDecompositionTree
},
[](PCGBinaryParallelSplit const &parallel) -> BinarySPDecompositionTree {
return BinarySPDecompositionTree{
BinaryParallelSplit{
binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()),
binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()),
},
binary_parallel_split_from_pcg_parallel_split(parallel),
};
},
[](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree {
Expand All @@ -82,9 +80,35 @@ BinarySPDecompositionTree
});
}

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) {
NOT_IMPLEMENTED();
PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree(
BinarySPDecompositionTree const &spd_tree) {
return spd_tree.visit<PCGBinarySPDecomposition>(overload{
[](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
PCGBinarySeriesSplit{
pcg_binary_sp_decomposition_from_binary_sp_tree(
series.get_left_child()),
pcg_binary_sp_decomposition_from_binary_sp_tree(
series.get_right_child()),
},
};
},
[](BinaryParallelSplit const &parallel) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
PCGBinaryParallelSplit{
pcg_binary_sp_decomposition_from_binary_sp_tree(
parallel.get_left_child()),
pcg_binary_sp_decomposition_from_binary_sp_tree(
parallel.get_right_child()),
},
};
},
[](Node const &node) -> PCGBinarySPDecomposition {
return PCGBinarySPDecomposition{
parallel_layer_guid_t{node},
};
},
});
}

std::unordered_multiset<parallel_layer_guid_t>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "test/utils/doctest/fmt/unordered_multiset.h"
#include "test/utils/rapidcheck.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_tree") {
Node n1 = Node{1};
Node n2 = Node{2};
Node n3 = Node{3};

auto make_binary_series_split = [](BinarySPDecompositionTree const &lhs,
BinarySPDecompositionTree const &rhs) {
return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}};
};

auto make_binary_parallel_split = [](BinarySPDecompositionTree const &lhs,
BinarySPDecompositionTree const &rhs) {
return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}};
};

auto make_binary_leaf = [](Node const &n) {
return BinarySPDecompositionTree{n};
};

auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs,
PCGBinarySPDecomposition const &rhs) {
return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}};
};

auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs,
PCGBinarySPDecomposition const &rhs) {
return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}};
};

auto make_pcg_leaf = [](Node const &n) {
return PCGBinarySPDecomposition{parallel_layer_guid_t{n}};
};

SUBCASE("single node") {
BinarySPDecompositionTree input = make_binary_leaf(n1);

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected = make_pcg_leaf(n1);

CHECK(result == expected);
}

SUBCASE("series split") {
BinarySPDecompositionTree input =
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2));

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected =
make_pcg_series_split(make_pcg_leaf(n1), make_pcg_leaf(n2));

CHECK(result == expected);
}

SUBCASE("parallel split") {
BinarySPDecompositionTree input = make_binary_parallel_split(
make_binary_leaf(n1), make_binary_leaf(n2));

PCGBinarySPDecomposition result =
pcg_binary_sp_decomposition_from_binary_sp_tree(input);

PCGBinarySPDecomposition expected =
make_pcg_parallel_split(make_pcg_leaf(n1), make_pcg_leaf(n2));

CHECK(result == expected);
}

SUBCASE("bijectiveness") {
BinarySPDecompositionTree original = make_binary_parallel_split(
make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)),
make_binary_leaf(n3));

PCGBinarySPDecomposition pcg_tree =
pcg_binary_sp_decomposition_from_binary_sp_tree(original);
BinarySPDecompositionTree converted =
binary_sp_tree_from_pcg_sp_tree(pcg_tree);

CHECK(original == converted);
}
}
}
31 changes: 31 additions & 0 deletions lib/utils/include/utils/full_binary_tree/get_tree_height.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_TREE_HEIGHT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_TREE_HEIGHT_H

#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h"
#include "utils/full_binary_tree/visit.h"
#include "utils/nonnegative_int/nonnegative_int.h"

namespace FlexFlow {

template <typename Tree, typename Parent, typename Leaf>
nonnegative_int get_tree_height(
Tree const &tree,
FullBinaryTreeImplementation<Tree, Parent, Leaf> const &impl) {

auto visitor = FullBinaryTreeVisitor<nonnegative_int, Tree, Parent, Leaf>{
[&](Parent const &parent) -> nonnegative_int {
nonnegative_int left_height =
get_tree_height(impl.get_left_child(parent), impl);
nonnegative_int right_height =
get_tree_height(impl.get_right_child(parent), impl);
return std::max(left_height, right_height) + 1_n;
},
[](Leaf const &) -> nonnegative_int { return 0_n; },
};

return visit(tree, impl, visitor);
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BALANCED_BINARY_SP_TREE_FROM_NARY_H

#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"

namespace FlexFlow {

BinarySPDecompositionTree
balanced_binary_sp_tree_from_nary(SeriesParallelDecomposition const &nary);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h"
#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h"
#include "utils/nonnegative_int/nonnegative_int.h"
#include <unordered_set>

namespace FlexFlow {
Expand All @@ -23,6 +24,8 @@ std::unordered_multiset<Node> get_leaves(BinarySPDecompositionTree const &);

SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &);

nonnegative_int get_tree_height(BinarySPDecompositionTree const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_TREE_HEIGHT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_TREE_HEIGHT_H

#include "utils/full_binary_tree/get_tree_height.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h"

namespace FlexFlow {

template <typename Tree, typename Series, typename Parallel, typename Leaf>
nonnegative_int get_tree_height(
Tree const &tree,
GenericBinarySPDecompositionTreeImplementation<Tree,
Series,
Parallel,
Leaf> const &impl) {

FullBinaryTreeImplementation<Tree, std::variant<Series, Parallel>, Leaf>
full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl);

return get_tree_height(tree, full_binary_impl);
}

} // namespace FlexFlow

#endif
14 changes: 14 additions & 0 deletions lib/utils/src/utils/full_binary_tree/get_tree_height.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "utils/full_binary_tree/get_tree_height.h"
#include "utils/archetypes/value_type.h"

namespace FlexFlow {

using Tree = value_type<0>;
using Parent = value_type<1>;
using Leaf = value_type<2>;

template nonnegative_int
get_tree_height(Tree const &,
FullBinaryTreeImplementation<Tree, Parent, Leaf> const &);

} // namespace FlexFlow
Loading
Loading