88# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99# See the License for the specific language governing permissions and
1010# limitations under the License.
11-
12- from typing import List , Set
11+ from typing import Dict , List
1312
1413import networkx as nx
1514import networkx .algorithms .isomorphism as ism
1615
1716from nncf .common .graph .patterns import GraphPattern
1817
1918
20- def is_subgraph_has_inner_outgoing_edges (
21- graph : nx .DiGraph , full_subgraph_with_non_pattern_nodes : List [str ], pattern_subgraph : List [str ]
22- ) -> bool :
19+ def _are_nodes_matched (node_1 , node_2 ) -> bool :
20+ for attr in node_2 :
21+ if attr == GraphPattern .LABEL_ATTR :
22+ continue
23+ if attr == GraphPattern .METATYPE_ATTR :
24+ # GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
25+ # are matched to any node type.
26+ if GraphPattern .ANY_PATTERN_NODE_TYPE in node_2 [attr ] or GraphPattern .NON_PATTERN_NODE_TYPE in node_2 [attr ]:
27+ continue
28+ # Torch and TF pattern mapping based on 'type' section,
29+ # While ONNX mapping based on metatypes -
30+ # to support all of them, we need to check the existane of the attributes
31+ if GraphPattern .NODE_TYPE_ATTR in node_1 :
32+ if node_1 [GraphPattern .NODE_TYPE_ATTR ] in node_2 [attr ]:
33+ continue
34+ if node_1 [attr ] not in node_2 [attr ]:
35+ return False
36+ return True
37+
38+
39+ def _sort_patterns_by_len (pattern : nx .DiGraph ) -> int :
2340 """
24- Checks out whether the 'pattern_subgraph' has outgoing edges,
25- that aren't connected with nodes from full_subgraph_with_non_pattern_nodes.
41+ Sort patterns by their length. GraphPattern.NON_PATTERN_NODE_TYPE is not counted as a pattern node.
42+ """
43+ non_pattern_nodes = [
44+ node_id
45+ for node_id , node_data in pattern .nodes (data = True )
46+ if GraphPattern .NON_PATTERN_NODE_TYPE in node_data [GraphPattern .METATYPE_ATTR ]
47+ ]
48+ return len (pattern ) - len (non_pattern_nodes )
49+
50+
51+ def _is_subgraph_matching_strict (graph : nx .DiGraph , pattern : nx .DiGraph , subgraph : Dict [str , str ]) -> bool :
52+ """
53+ Checks out whether the matched subgraph has:
54+ 1) External predecessors of starting nodes.
55+ 2) External successors of the last nodes.
56+ 3) External successors or predecessors of the nodes which are not starting and last.
57+ If any of these conditions is True, than returns False, otherwise - True.
58+ The checks are skipped for NON_PATTERN_NODE_TYPE.
2659 Example:
60+ This subgraph matching is not strict.
2761 (conv2d + BN + ReLU pattern):
2862 ...
2963 |
@@ -37,119 +71,81 @@ def is_subgraph_has_inner_outgoing_edges(
3771 |
3872 ...
3973 :param graph: The model graph.
40- :param full_subgraph_with_non_pattern_nodes: A subgraph of the model graph including the nodes outside the pattern.
41- :param pattern_subgraph: A subgraph of the model.
42- :return: True if the subgraph contains outgoing edges starting not from the last node,
43- False - otherwise.
74+ :param pattern: The matched pattern.
75+ :param subgraph: A subgraph of the model graph including the nodes outside the pattern.
76+ :return: If any of three conditions is True than returns False, otherwise - True.
77+ """
78+ starting_nodes = []
79+ last_nodes = []
80+ for node in pattern .nodes :
81+ if not pattern .pred [node ] and pattern .succ [node ]:
82+ starting_nodes .append (node )
83+ if pattern .pred [node ] and not pattern .succ [node ]:
84+ last_nodes .append (node )
85+
86+ for node_from_graph , node_from_pattern in subgraph .items ():
87+ if GraphPattern .NON_PATTERN_NODE_TYPE in pattern .nodes [node_from_pattern ].get (GraphPattern .METATYPE_ATTR ):
88+ continue
89+ predecessors_keys = graph .pred [node_from_graph ].keys ()
90+ successor_keys = graph .succ [node_from_graph ].keys ()
91+ has_external_successors = any (successor_key not in subgraph for successor_key in successor_keys )
92+ has_external_predcessors = any (predecessor_key not in subgraph for predecessor_key in predecessors_keys )
93+ if node_from_pattern in starting_nodes and has_external_successors :
94+ return False
95+ if node_from_pattern in last_nodes and has_external_predcessors :
96+ return False
97+ if (node_from_pattern not in last_nodes and node_from_pattern not in starting_nodes ) and (
98+ has_external_successors or has_external_predcessors
99+ ):
100+ return False
101+ return True
102+
103+
104+ def _copy_subgraph_excluding_non_pattern_node (subgraph : Dict [str , str ], pattern_graph : GraphPattern ) -> Dict [str , str ]:
44105 """
45- first_node = pattern_subgraph [0 ]
46- last_node = pattern_subgraph [- 1 ]
47- for node_key in pattern_subgraph :
48- if node_key == last_node :
49- predecessors = list (graph .pred [node_key ].keys ())
50- if any (predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors ):
51- return True
52- elif node_key == first_node :
53- successors = list (graph .succ [node_key ].keys ())
54- if any (successor not in full_subgraph_with_non_pattern_nodes for successor in successors ):
55- return True
56- else :
57- successors = list (graph .succ [node_key ].keys ())
58- predecessors = list (graph .pred [node_key ].keys ())
59- if any (successors_key not in full_subgraph_with_non_pattern_nodes for successors_key in successors ):
60- return True
61- if any (predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors ):
62- return True
63- return False
106+ Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE.
107+
108+ :param subgraph: Subgraph
109+ :param pattern_graph: A graph consists of patterns to match.
110+ :return: New subgraph without excluded nodes.
111+ """
112+ output = {}
113+ for node_from_graph , node_from_pattern in subgraph .items ():
114+ pattern_node = pattern_graph .graph .nodes [node_from_pattern ]
115+ pattern_node_types = pattern_node .get (GraphPattern .METATYPE_ATTR )
116+ if GraphPattern .NON_PATTERN_NODE_TYPE not in pattern_node_types :
117+ output [node_from_graph ] = node_from_pattern
118+ return output
64119
65120
66121def find_subgraphs_matching_pattern (graph : nx .DiGraph , pattern_graph : GraphPattern ) -> List [List [str ]]:
67122 """
68- Find a list of subgraphs for the particular graph that match the pattern expression.
123+ Finds a list of nodes which define a subgraph matched a pattern in pattern_graph.
124+ Nodes in each subgraph is stored in lexicographical_topological_sort.
125+
69126 :param graph: The model graph.
70- :param pattern_graph: A graph consists of patterns for layer fusing logic.
71- :return: A list of subgraphs, matching the pattern expression.
72- Each subgraph is defined as a list of node keys.
127+ :param pattern_graph: A graph consists of patterns to match.
128+ :return: A list of subgraphs are mathced to the patterns. Each subgraph is defined as a list of node keys.
73129 """
74-
75- def are_nodes_matching (node_1 , node_2 ):
76- for attr in node_2 :
77- if attr == GraphPattern .LABEL_ATTR :
78- continue
79- if attr == GraphPattern .METATYPE_ATTR :
80- # GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
81- # are matched to any node type.
82-
83- if (
84- GraphPattern .ANY_PATTERN_NODE_TYPE in node_2 [attr ]
85- or GraphPattern .NON_PATTERN_NODE_TYPE in node_2 [attr ]
86- ):
87- continue
88- # Torch and TF pattern mapping based on 'type' section,
89- # While ONNX mapping based on metatypes -
90- # to support all of them, we need to check the existane of the attributes
91- if GraphPattern .NODE_TYPE_ATTR in node_1 :
92- if node_1 [GraphPattern .NODE_TYPE_ATTR ] in node_2 [attr ]:
93- continue
94- if node_1 [attr ] not in node_2 [attr ]:
95- return False
96- return True
97-
98- def are_edges_matching (edge_1 , edge_2 ):
99- for attr in edge_2 :
100- if edge_1 [attr ] not in edge_2 [attr ]:
101- return False
102- return True
103-
104- subgraphs = [] # type: List[List[str]]
105- visited_nodes = set () # type: Set[str]
106- patterns = [] # type: List[nx.DiGraph]
107- for c in nx .weakly_connected_components (pattern_graph .graph ):
108- patterns .append (pattern_graph .graph .subgraph (c ))
109-
110- def sort_patterns (pattern : nx .DiGraph ):
111- """
112- Sort patterns by their length,
113- keeping in mind that if node type is GraphPattern.NON_PATTERN_NODE_TYPE it shouldn't count.
114- """
115- pattern_len = len (pattern )
116- for node in pattern .nodes :
117- if GraphPattern .NON_PATTERN_NODE_TYPE in pattern_graph .graph .nodes .get (node )[GraphPattern .METATYPE_ATTR ]:
118- pattern_len -= 1
119- return pattern_len
120-
121- # Get all patterns sorted by their lengths
122- # as we want match the longest patterns first
123-
124- patterns = sorted (patterns , key = sort_patterns , reverse = True )
125-
130+ subgraphs = []
131+ matched_nodes = set ()
132+ patterns = pattern_graph .get_weakly_connected_subgraphs ()
133+ patterns = sorted (patterns , key = _sort_patterns_by_len , reverse = True )
126134 for pattern in patterns :
127- matcher = ism .DiGraphMatcher (graph , pattern , node_match = are_nodes_matching , edge_match = are_edges_matching )
135+ matcher = ism .DiGraphMatcher (graph , pattern , node_match = _are_nodes_matched )
128136 for subgraph in matcher .subgraph_isomorphisms_iter ():
129- # Bottleneck that need to sort by id for result consistency
130- pattern_subgraph = list (
131- nx .lexicographical_topological_sort (graph .subgraph (subgraph ), key = lambda x : int (x .split ()[0 ]))
132- )
133-
134- full_subgraph_with_non_pattern_nodes = pattern_subgraph [:]
135- outside_pattern_nodes = []
136-
137- # If some nodes are outside the pattern - remove them from pattern_subgraph
138-
139- for node , pattern_node_id in matcher .mapping .items ():
140- pattern_node = pattern_graph .graph .nodes [pattern_node_id ]
141- pattern_node_types = pattern_node .get (GraphPattern .METATYPE_ATTR )
142- if GraphPattern .NON_PATTERN_NODE_TYPE in pattern_node_types :
143- outside_pattern_nodes .append (node )
144- for node in outside_pattern_nodes :
145- pattern_subgraph .remove (node )
146-
147- is_visited_node = any (node in visited_nodes for node in pattern_subgraph )
148- if is_visited_node :
137+ is_matching_strict = _is_subgraph_matching_strict (graph , pattern , subgraph )
138+ if not is_matching_strict :
149139 continue
150- if is_subgraph_has_inner_outgoing_edges (graph , full_subgraph_with_non_pattern_nodes , pattern_subgraph ):
140+
141+ subgraph = _copy_subgraph_excluding_non_pattern_node (subgraph , pattern_graph )
142+ is_any_node_matched = any (node in matched_nodes for node in subgraph )
143+
144+ if is_any_node_matched :
151145 continue
152- visited_nodes .update (pattern_subgraph )
153- subgraphs .append (pattern_subgraph )
154146
155- return subgraphs if subgraphs else []
147+ matched_nodes .update (subgraph )
148+ sorted_nodes_subgraph = list (nx .lexicographical_topological_sort (graph .subgraph (subgraph )))
149+ subgraphs .append (sorted_nodes_subgraph )
150+
151+ return subgraphs
0 commit comments