Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "utils/graph/dataflow_graph/algorithms.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h"
#include "utils/overload.h"

namespace FlexFlow {

Expand Down Expand Up @@ -67,6 +68,27 @@ static std::optional<UnlabelledDataflowGraphPatternMatch>
return match;
}

MatchAdditionalCriterion additional_criterion_for_subpattern(
MatchAdditionalCriterion const &full_additional_criterion,
bidict<PatternValue, PatternInput> const
&full_pattern_values_to_subpattern_inputs) {
return MatchAdditionalCriterion{
full_additional_criterion.node_criterion,
[&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) {
return patternValue.visit<bool>(
overload{[&](PatternNodeOutput const &) -> bool {
return full_additional_criterion.value_criterion(
patternValue, pcgValue);
},
[&](PatternInput const &i) -> bool {
PatternValue full_pattern_value =
full_pattern_values_to_subpattern_inputs.at_r(i);
return full_additional_criterion.value_criterion(
full_pattern_value, pcgValue);
}});
}};
}

std::vector<UnlabelledDataflowGraphPatternMatch>
find_pattern_matches(UnlabelledGraphPattern const &pattern,
OpenDataflowGraphView const &graph,
Expand All @@ -87,10 +109,18 @@ std::vector<UnlabelledDataflowGraphPatternMatch>
PatternSplitResult subpatterns = apply_split(pattern, split);
std::vector<UnlabelledDataflowGraphPatternMatch> prefix_matches =
find_pattern_matches(
subpatterns.subpattern_1, graph, additional_criterion);
subpatterns.subpattern_1,
graph,
additional_criterion_for_subpattern(
additional_criterion,
subpatterns.full_pattern_values_to_subpattern_1_inputs));
std::vector<UnlabelledDataflowGraphPatternMatch> postfix_matches =
find_pattern_matches(
subpatterns.subpattern_2, graph, additional_criterion);
subpatterns.subpattern_2,
graph,
additional_criterion_for_subpattern(
additional_criterion,
subpatterns.full_pattern_values_to_subpattern_2_inputs));

for (UnlabelledDataflowGraphPatternMatch const &prefix_match :
prefix_matches) {
Expand Down
33 changes: 29 additions & 4 deletions lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@
#include "substitutions/unlabelled/unlabelled_graph_pattern.h"
#include "utils/bidict/algorithms/left_entries.h"
#include "utils/bidict/algorithms/right_entries.h"
#include "utils/containers/is_subseteq_of.h"
#include "utils/containers/keys.h"
#include "utils/containers/transform.h"
#include "utils/containers/values.h"
#include "utils/graph/dataflow_graph/algorithms.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h"
#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h"
#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h"
#include "utils/overload.h"
#include <libassert/assert.hpp>
#include <memory>

namespace FlexFlow {
Expand Down Expand Up @@ -46,8 +50,13 @@ struct SubgraphConcreteFromPattern {
}

OpenDataflowValue operator()(PatternInput const &i) const {
return OpenDataflowValue{full_graph_values_to_subgraph_inputs.at_l(
match.input_assignment.at(i))};
OpenDataflowValue mapped_input = match.input_assignment.at(i);
if (full_graph_values_to_subgraph_inputs.contains_l(mapped_input)) {
return OpenDataflowValue{
full_graph_values_to_subgraph_inputs.at_l(mapped_input)};
} else {
return mapped_input;
}
}

OpenDataflowEdge operator()(InputPatternEdge const &e) const {
Expand Down Expand Up @@ -148,11 +157,27 @@ bool unlabelled_pattern_does_match(
UnlabelledDataflowGraphPatternMatch const &match,
MatchAdditionalCriterion const &additional_criterion) {

std::unordered_set<OpenDataflowValue> matched_by_pattern_inputs =
unordered_set_of(values(match.input_assignment));

ASSERT(left_entries(match.node_assignment) == get_nodes(pattern));
ASSERT(
is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph)));
ASSERT(keys(match.input_assignment) == get_graph_inputs(pattern));
ASSERT(is_subseteq_of(matched_by_pattern_inputs,
get_open_dataflow_values(graph)));

OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match);
OpenDataflowGraphView matched_subgraph = subgraph_result.graph;

assert(left_entries(match.node_assignment) == get_nodes(pattern));
assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph));
std::unordered_set<OpenDataflowValue> full_values_split_by_subgraph =
left_entries(subgraph_result.full_graph_values_to_subgraph_inputs);

ASSERT(right_entries(match.node_assignment) == get_nodes(matched_subgraph));
ASSERT(is_subseteq_of(full_values_split_by_subgraph,
get_open_dataflow_values(graph)),
full_values_split_by_subgraph,
get_open_dataflow_values(graph));

MatchAdditionalCriterion through_subgraph_operation =
MatchAdditionalCriterion{
Expand Down
Loading
Loading