diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index a7ebc0bff7..9d8e4bc259 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -11,6 +11,7 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/overload.h" namespace FlexFlow { @@ -67,6 +68,27 @@ static std::optional return match; } +MatchAdditionalCriterion additional_criterion_for_subpattern( + MatchAdditionalCriterion const &full_additional_criterion, + bidict const + &full_pattern_values_to_subpattern_inputs) { + return MatchAdditionalCriterion{ + full_additional_criterion.node_criterion, + [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { + return patternValue.visit( + overload{[&](PatternNodeOutput const &) -> bool { + return full_additional_criterion.value_criterion( + patternValue, pcgValue); + }, + [&](PatternInput const &i) -> bool { + PatternValue full_pattern_value = + full_pattern_values_to_subpattern_inputs.at_r(i); + return full_additional_criterion.value_criterion( + full_pattern_value, pcgValue); + }}); + }}; +} + std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &graph, @@ -87,10 +109,18 @@ std::vector PatternSplitResult subpatterns = apply_split(pattern, split); std::vector prefix_matches = find_pattern_matches( - subpatterns.subpattern_1, graph, additional_criterion); + subpatterns.subpattern_1, + graph, + additional_criterion_for_subpattern( + additional_criterion, + subpatterns.full_pattern_values_to_subpattern_1_inputs)); std::vector postfix_matches = find_pattern_matches( - subpatterns.subpattern_2, graph, additional_criterion); + subpatterns.subpattern_2, + graph, + additional_criterion_for_subpattern( + additional_criterion, + subpatterns.full_pattern_values_to_subpattern_2_inputs)); for (UnlabelledDataflowGraphPatternMatch const &prefix_match : prefix_matches) { diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 304bb8cf46..c7b03e24f2 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -7,10 +7,13 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" @@ -18,6 +21,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" +#include #include namespace FlexFlow { @@ -46,8 +50,13 @@ struct SubgraphConcreteFromPattern { } OpenDataflowValue operator()(PatternInput const &i) const { - return OpenDataflowValue{full_graph_values_to_subgraph_inputs.at_l( - match.input_assignment.at(i))}; + OpenDataflowValue mapped_input = match.input_assignment.at(i); + if (full_graph_values_to_subgraph_inputs.contains_l(mapped_input)) { + return OpenDataflowValue{ + full_graph_values_to_subgraph_inputs.at_l(mapped_input)}; + } else { + return mapped_input; + } } OpenDataflowEdge operator()(InputPatternEdge const &e) const { @@ -148,11 +157,27 @@ bool unlabelled_pattern_does_match( UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { + std::unordered_set matched_by_pattern_inputs = + unordered_set_of(values(match.input_assignment)); + + ASSERT(left_entries(match.node_assignment) == get_nodes(pattern)); + ASSERT( + is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph))); + ASSERT(keys(match.input_assignment) == get_graph_inputs(pattern)); + ASSERT(is_subseteq_of(matched_by_pattern_inputs, + get_open_dataflow_values(graph))); + OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); OpenDataflowGraphView matched_subgraph = subgraph_result.graph; - assert(left_entries(match.node_assignment) == get_nodes(pattern)); - assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); + std::unordered_set full_values_split_by_subgraph = + left_entries(subgraph_result.full_graph_values_to_subgraph_inputs); + + ASSERT(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); + ASSERT(is_subseteq_of(full_values_split_by_subgraph, + get_open_dataflow_values(graph)), + full_values_split_by_subgraph, + get_open_dataflow_values(graph)); MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{ diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 8ba1fee873..4dbf0885dd 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -13,144 +13,260 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find_pattern_matches(PCGPattern, SubParallelComputationGraph)") { - ParallelComputationGraphBuilder builder; - - nonnegative_int batch_size = 16_n; - nonnegative_int batch_degree = 2_n; - nonnegative_int num_channels = 24_n; - - TensorShape a_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - num_channels, - }, - }, - DataType::FLOAT, - }; - - std::string a_name = "a"; - - parallel_tensor_guid_t a_tensor = builder.create_input_tensor(a_shape); - a_tensor = - builder.parallel_partition(a_tensor, ff_dim_t{0_n}, batch_degree); - - nonnegative_int outDim = 16_n; - std::string x_matmul_name = "x_matmul"; - std::string y_matmul_name = "y_matmul"; - parallel_tensor_guid_t t0 = - builder.dense(a_tensor, - outDim, - /*activation=*/std::nullopt, - /*use_bias=*/false, - DataType::FLOAT, - /*kernel_initializer=*/std::nullopt, - /*bias_initializer=*/std::nullopt, - x_matmul_name); - parallel_tensor_guid_t t1 = - builder.dense(a_tensor, - outDim, - /*activation=*/std::nullopt, - /*use_bias=*/false, - DataType::FLOAT, - /*kernel_initializer=*/std::nullopt, - /*bias_initializer=*/std::nullopt, - y_matmul_name); - parallel_tensor_guid_t t2 = builder.add(t0, t1); - - ParallelComputationGraph pcg = builder.pcg; - parallel_layer_guid_t x_matmul = - get_parallel_layer_by_name(pcg, x_matmul_name); - parallel_layer_guid_t y_matmul = - get_parallel_layer_by_name(pcg, y_matmul_name); - std::vector x_incoming = - get_incoming_tensors(pcg, x_matmul); - REQUIRE(x_incoming.size() == 2); - parallel_tensor_guid_t x_weights = x_incoming.at(1); - std::vector y_incoming = - get_incoming_tensors(pcg, y_matmul); - REQUIRE(y_incoming.size() == 2); - parallel_tensor_guid_t y_weights = y_incoming.at(1); - - LabelledOpenDataflowGraph - g = LabelledOpenDataflowGraph:: - create>(); - - TensorAttributePattern pattern_tensor_a = - tensor_attribute_pattern_match_all(); - TensorAttributePattern pattern_tensor_b = - tensor_attribute_pattern_match_all(); - TensorAttributePattern pattern_tensor_c = - tensor_attribute_pattern_match_all(); - TensorAttributePattern pattern_tensor_x = - tensor_attribute_pattern_match_all(); - TensorAttributePattern pattern_tensor_y = - tensor_attribute_pattern_match_all(); - - OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ - op_type_equals_constraint(OperatorType::LINEAR), - }}; - - OperatorAttributePattern op_pattern_2 = op_pattern_1; - - DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); - DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); - DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); - - NodeAddedResult op_pattern_1_added = - g.add_node(op_pattern_1, - {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, - {pattern_tensor_x}); - PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; - OpenDataflowValue pt_x = - OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; - - NodeAddedResult op_pattern_2_added = - g.add_node(op_pattern_2, - {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_c}}, - {pattern_tensor_y}); - PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; - OpenDataflowValue pt_y = - OpenDataflowValue{get_only(op_pattern_2_added.outputs)}; - - PCGPattern pattern = PCGPattern{g}; - - std::unordered_set result = unordered_set_of( - find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); - - PCGPatternMatch match1 = - PCGPatternMatch{bidict{ - {op_pattern_1_node, x_matmul}, - {op_pattern_2_node, y_matmul}, - }, - bidict{ - {PatternInput{pt_a}, - open_parallel_tensor_guid_from_closed(a_tensor)}, - {PatternInput{pt_b}, - open_parallel_tensor_guid_from_closed(x_weights)}, - {PatternInput{pt_c}, - open_parallel_tensor_guid_from_closed(y_weights)}, - }}; - - PCGPatternMatch match2 = - PCGPatternMatch{bidict{ - {op_pattern_1_node, y_matmul}, - {op_pattern_2_node, x_matmul}, - }, - bidict{ - {PatternInput{pt_a}, - open_parallel_tensor_guid_from_closed(a_tensor)}, - {PatternInput{pt_b}, - open_parallel_tensor_guid_from_closed(y_weights)}, - {PatternInput{pt_c}, - open_parallel_tensor_guid_from_closed(x_weights)}, - }}; - - std::unordered_set correct = {match1, match2}; - - CHECK(result == correct); + SUBCASE("simple case") { + ParallelComputationGraphBuilder builder; + + nonnegative_int batch_size = 16_n; + nonnegative_int batch_degree = 2_n; + nonnegative_int num_channels = 24_n; + + TensorShape a_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + num_channels, + }, + }, + DataType::FLOAT, + }; + + std::string a_name = "a"; + + parallel_tensor_guid_t a_tensor = builder.create_input_tensor(a_shape); + a_tensor = + builder.parallel_partition(a_tensor, ff_dim_t{0_n}, batch_degree); + + nonnegative_int outDim = 16_n; + std::string x_matmul_name = "x_matmul"; + std::string y_matmul_name = "y_matmul"; + parallel_tensor_guid_t t0 = + builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + x_matmul_name); + parallel_tensor_guid_t t1 = + builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + y_matmul_name); + parallel_tensor_guid_t t2 = builder.add(t0, t1); + + ParallelComputationGraph pcg = builder.pcg; + parallel_layer_guid_t x_matmul = + get_parallel_layer_by_name(pcg, x_matmul_name); + parallel_layer_guid_t y_matmul = + get_parallel_layer_by_name(pcg, y_matmul_name); + std::vector x_incoming = + get_incoming_tensors(pcg, x_matmul); + REQUIRE(x_incoming.size() == 2); + parallel_tensor_guid_t x_weights = x_incoming.at(1); + std::vector y_incoming = + get_incoming_tensors(pcg, y_matmul); + REQUIRE(y_incoming.size() == 2); + parallel_tensor_guid_t y_weights = y_incoming.at(1); + + LabelledOpenDataflowGraph + g = LabelledOpenDataflowGraph:: + create>(); + + TensorAttributePattern pattern_tensor_a = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_b = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_c = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_x = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_y = + tensor_attribute_pattern_match_all(); + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + }}; + + OperatorAttributePattern op_pattern_2 = op_pattern_1; + + DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); + DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); + DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); + + NodeAddedResult op_pattern_1_added = + g.add_node(op_pattern_1, + {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, + {pattern_tensor_x}); + PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; + OpenDataflowValue pt_x = + OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; + + NodeAddedResult op_pattern_2_added = + g.add_node(op_pattern_2, + {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_c}}, + {pattern_tensor_y}); + PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; + OpenDataflowValue pt_y = + OpenDataflowValue{get_only(op_pattern_2_added.outputs)}; + + PCGPattern pattern = PCGPattern{g}; + + std::unordered_set result = unordered_set_of( + find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + + PCGPatternMatch match1 = PCGPatternMatch{ + bidict{ + {op_pattern_1_node, x_matmul}, + {op_pattern_2_node, y_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(x_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(y_weights)}, + }}; + + PCGPatternMatch match2 = PCGPatternMatch{ + bidict{ + {op_pattern_1_node, y_matmul}, + {op_pattern_2_node, x_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(y_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(x_weights)}, + }}; + + std::unordered_set correct = {match1, match2}; + + CHECK(result == correct); + } + + SUBCASE("pcg is a chain") { + ParallelComputationGraphBuilder builder; + + nonnegative_int batch_size = 16_n; + nonnegative_int batch_degree = 2_n; + nonnegative_int num_channels = 24_n; + + TensorShape a_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + num_channels, + }, + }, + DataType::FLOAT, + }; + + std::string a_name = "a"; + + parallel_tensor_guid_t a_tensor = builder.create_input_tensor(a_shape); + a_tensor = + builder.parallel_partition(a_tensor, ff_dim_t{0_n}, batch_degree); + + nonnegative_int outDim = 16_n; + std::string x_matmul_name = "x_matmul"; + std::string y_matmul_name = "y_matmul"; + parallel_tensor_guid_t t0 = + builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + x_matmul_name); + parallel_tensor_guid_t t1 = + builder.dense(t0, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + y_matmul_name); + parallel_tensor_guid_t t2 = + builder.dense(t1, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + parallel_tensor_guid_t t3 = + builder.dense(t2, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + ParallelComputationGraph pcg = builder.pcg; + + LabelledOpenDataflowGraph + g = LabelledOpenDataflowGraph:: + create>(); + + TensorAttributePattern pattern_tensor_a = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_b = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_c = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_x = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_y = + tensor_attribute_pattern_match_all(); + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + }}; + + OperatorAttributePattern op_pattern_2 = op_pattern_1; + + DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); + DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); + DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); + + NodeAddedResult op_pattern_1_added = + g.add_node(op_pattern_1, + {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, + {pattern_tensor_x}); + PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; + OpenDataflowValue pt_x = + OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; + + NodeAddedResult op_pattern_2_added = + g.add_node(op_pattern_2, + {OpenDataflowValue{pt_x}, OpenDataflowValue{pt_c}}, + {pattern_tensor_y}); + PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; + + PCGPattern pattern = PCGPattern{g}; + + std::unordered_set result = unordered_set_of( + find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + + CHECK(result.size() == 3); + } } } diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h index 202058a3d1..f5bbbc228d 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" #include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" @@ -10,6 +11,17 @@ namespace FlexFlow { OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &, std::unordered_set const &); +bidict + get_full_graph_values_to_subgraph_inputs( + OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes); + +OpenDataflowGraphData + get_subgraph_data(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + bidict const + &full_graph_values_to_subgraph_inputs); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index ad3d4f26c0..36f027f792 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -4,7 +4,11 @@ #include "utils/containers/is_subseteq_of.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" @@ -13,100 +17,89 @@ namespace FlexFlow { -struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { - OpenDataflowSubgraph(OpenDataflowGraphView const &full_graph, - std::unordered_set const &subgraph_nodes, - bidict const - &full_graph_values_to_subgraph_inputs) - : full_graph(full_graph), subgraph_nodes(subgraph_nodes), - full_graph_values_to_subgraph_inputs( - full_graph_values_to_subgraph_inputs) { - assert(is_subseteq_of(this->subgraph_nodes, get_nodes(full_graph))); - } - - std::unordered_set query_nodes(NodeQuery const &q) const override { - return intersection(this->full_graph.query_nodes(q), this->subgraph_nodes); - } - - std::unordered_set - query_edges(OpenDataflowEdgeQuery const &q) const override { - std::unordered_set result; - for (OpenDataflowEdge const &open_e : this->full_graph.query_edges(q)) { - open_e.visit(overload{ - [&](DataflowEdge const &e) { - bool contains_src = contains(this->subgraph_nodes, e.src.node); - bool contains_dst = contains(this->subgraph_nodes, e.dst.node); - if (contains_src && contains_dst) { - result.insert(OpenDataflowEdge{e}); - } else if (contains_dst && !contains_src) { - result.insert(OpenDataflowEdge{DataflowInputEdge{ - this->full_graph_values_to_subgraph_inputs.at_l( - OpenDataflowValue{e.src}), - e.dst}}); - } - return std::nullopt; - }, - [&](DataflowInputEdge const &e) { - if (contains(this->subgraph_nodes, e.dst.node)) { - result.insert(OpenDataflowEdge{DataflowInputEdge{ - this->full_graph_values_to_subgraph_inputs.at_l( - OpenDataflowValue{e.src}), - e.dst}}); - } - return std::nullopt; - }}); - } - return result; - } - - std::unordered_set - query_outputs(DataflowOutputQuery const &q) const override { - return filter(this->full_graph.query_outputs(q), - [&](DataflowOutput const &o) { - return contains(this->subgraph_nodes, o.node); - }); - } - - std::unordered_set get_inputs() const override { - return unordered_set_of(values(this->full_graph_values_to_subgraph_inputs)); - }; - - OpenDataflowSubgraph *clone() const override { - return new OpenDataflowSubgraph{ - this->full_graph, - this->subgraph_nodes, - this->full_graph_values_to_subgraph_inputs, - }; - } - -private: - OpenDataflowGraphView full_graph; - std::unordered_set subgraph_nodes; - bidict - full_graph_values_to_subgraph_inputs; -}; - OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &g, std::unordered_set const &subgraph_nodes) { - DataflowGraphInputSource input_source; bidict - full_graph_values_to_subgraph_inputs = generate_bidict( - get_subgraph_inputs(g, subgraph_nodes), - [&](OpenDataflowValue const &v) -> DataflowGraphInput { - return v.visit(overload{ - [](DataflowGraphInput const &i) { return i; }, - [&](DataflowOutput const &) { - return input_source.new_dataflow_graph_input(); - }, - }); - }); + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(g, subgraph_nodes); return OpenDataflowSubgraphResult{ - OpenDataflowGraphView::create( - g, subgraph_nodes, full_graph_values_to_subgraph_inputs), + OpenDataflowGraphView::create( + get_subgraph_data( + g, subgraph_nodes, full_graph_values_to_subgraph_inputs)), full_graph_values_to_subgraph_inputs, }; } +bidict + get_full_graph_values_to_subgraph_inputs( + OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes) { + DataflowGraphInputSource input_source; + return generate_bidict(get_subgraph_inputs(g, subgraph_nodes), + [&](OpenDataflowValue const &v) -> DataflowGraphInput { + return v.visit(overload{ + [](DataflowGraphInput const &i) { return i; }, + [&](DataflowOutput const &) { + return input_source.new_dataflow_graph_input(); + }, + }); + }); +} + +OpenDataflowGraphData + get_subgraph_data(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + bidict const + &full_graph_values_to_subgraph_inputs) { + std::unordered_set subgraph_input_edges = + transform(get_subgraph_incoming_edges(g, subgraph_nodes), + [&](OpenDataflowEdge const &edge) { + return edge.visit( + overload{[&](DataflowInputEdge const &e) { + return OpenDataflowEdge{DataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenDataflowValue{e.src}), + e.dst}}; + }, + [&](DataflowEdge const &e) { + return OpenDataflowEdge{DataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenDataflowValue{e.src}), + e.dst}}; + }}); + }); + + OpenDataflowEdgeQuery subgraph_interior_edges_query = OpenDataflowEdgeQuery{ + DataflowInputEdgeQuery{ + query_set::match_none(), + query_set::match_none(), + query_set::match_none(), + }, + DataflowEdgeQuery{ + query_set{subgraph_nodes}, + query_set::matchall(), + query_set{subgraph_nodes}, + query_set::matchall(), + }, + }; + std::unordered_set subgraph_interior_edges = + g.query_edges(subgraph_interior_edges_query); + + std::unordered_set subgraph_inputs = + unordered_set_of(values(full_graph_values_to_subgraph_inputs)); + std::unordered_set subgraph_outputs = + filter(g.query_outputs(dataflow_output_query_all()), + [&](DataflowOutput const &o) { + return contains(subgraph_nodes, o.node); + }); + return OpenDataflowGraphData{ + subgraph_nodes, + set_union(subgraph_input_edges, subgraph_interior_edges), + subgraph_inputs, + subgraph_outputs, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc new file mode 100644 index 0000000000..c44e5f81b7 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc @@ -0,0 +1,349 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/containers/contains.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_full_graph_values_to_subgraph_inputs(OpenDataflowGraphView, " + "std::unordered_set) ") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = graph.add_input(); + DataflowGraphInput i1 = graph.add_input(); + DataflowGraphInput i2 = graph.add_input(); + + NodeAddedResult n0_added = graph.add_node({OpenDataflowValue{i0}}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = graph.add_node({v0, OpenDataflowValue{i1}}, 1_n); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + + NodeAddedResult n2_added = graph.add_node({v0}, 1_n); + Node n2 = n2_added.node; + OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; + + NodeAddedResult n3_added = + graph.add_node({OpenDataflowValue{i2}, v1, v2}, 1_n); + Node n3 = n3_added.node; + + std::unordered_set subgraph_nodes = {n1, n2, n3}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + SUBCASE("left entries are correct") { + std::unordered_set correct = { + v0, OpenDataflowValue{i1}, OpenDataflowValue{i2}}; + CHECK(left_entries(full_graph_values_to_subgraph_inputs) == correct); + } + + SUBCASE("mapping is correct") { + CHECK(full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{i1}) == + i1); + CHECK(full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{i2}) == + i2); + std::unordered_set inputs = {i1, i2}; + CHECK(!contains(inputs, full_graph_values_to_subgraph_inputs.at_l(v0))); + } + } + + TEST_CASE( + "get_subgraph_data(OpenDataflowGraphView, std::unordered_set, " + "bidict)") { + SUBCASE("2-node graph without inputs") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + NodeAddedResult n0_added = graph.add_node({}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = graph.add_node({v0}, 1_n); + Node n1 = n1_added.node; + + SUBCASE("subgraph is full graph") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + {OpenDataflowEdge{ + DataflowEdge{DataflowOutput{n0, 0_n}, DataflowInput{n1, 0_n}}}}, + {}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is n0") { + std::unordered_set subgraph_nodes = {n0}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = OpenDataflowGraphData{subgraph_nodes, + {}, + {}, + {DataflowOutput{ + n0, + 0_n, + }}}; + CHECK(result == correct); + } + + SUBCASE("subgraph is n1") { + std::unordered_set subgraph_nodes = {n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + DataflowGraphInput n0_as_subgraph_input = + full_graph_values_to_subgraph_inputs.at_l(v0); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + {OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n1, 0_n}}}}, + {n0_as_subgraph_input}, + {DataflowOutput{ + n1, + 0_n, + }}}; + CHECK(result == correct); + } + + SUBCASE("subgraph is empty") { + std::unordered_set subgraph_nodes = {}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = + OpenDataflowGraphData{subgraph_nodes, {}, {}, {}}; + CHECK(result == correct); + } + } + + SUBCASE("3-node graph with inputs") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = graph.add_input(); + DataflowGraphInput i1 = graph.add_input(); + + NodeAddedResult n0_added = graph.add_node({OpenDataflowValue{i0}}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = + graph.add_node({v0, OpenDataflowValue{i1}}, 1_n); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = graph.add_node({v0}, 1_n); + Node n2 = n2_added.node; + + SUBCASE("subgraph is full graph") { + std::unordered_set subgraph_nodes = {n0, n1, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + OpenDataflowEdge{{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n2, 0_n}}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n1) split") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n1) split") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n2) split") { + std::unordered_set subgraph_nodes = {n0, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n2, 0_n}}}, + }, + {i0}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n1, n2) split") { + std::unordered_set subgraph_nodes = {n1, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + DataflowGraphInput n0_as_subgraph_input = + full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{v0}); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n1, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n2, 0_n}}}, + }, + {i1, n0_as_subgraph_input}, + { + DataflowOutput{ + n1, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + } + } +}