Skip to content

Commit 072504c

Browse files
Aleksei Kashapovandrey-churkin
andauthored
Move some HWFusedPatterns to IgnoredPatterns (#1905)
Changes: Move some patterns from HWFusedPatterns to IgnoredPatterns. Reason for changes: Some of the patterns inside HWFusedPatterns don't really align with the idea behind this class. These patterns were excluded from HWFusedPatterns and included in IgnoredPatterns Related tickets: 112515 Tests: TBD --------- Co-authored-by: Andrey Churkin <andrey.churkin@intel.com>
1 parent a8af396 commit 072504c

File tree

12 files changed

+289
-429
lines changed

12 files changed

+289
-429
lines changed

nncf/common/graph/graph_matching.py

Lines changed: 107 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,56 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11-
12-
from typing import List, Set
11+
from typing import Dict, List
1312

1413
import networkx as nx
1514
import networkx.algorithms.isomorphism as ism
1615

1716
from nncf.common.graph.patterns import GraphPattern
1817

1918

20-
def is_subgraph_has_inner_outgoing_edges(
21-
graph: nx.DiGraph, full_subgraph_with_non_pattern_nodes: List[str], pattern_subgraph: List[str]
22-
) -> bool:
19+
def _are_nodes_matched(node_1, node_2) -> bool:
20+
for attr in node_2:
21+
if attr == GraphPattern.LABEL_ATTR:
22+
continue
23+
if attr == GraphPattern.METATYPE_ATTR:
24+
# GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
25+
# are matched to any node type.
26+
if GraphPattern.ANY_PATTERN_NODE_TYPE in node_2[attr] or GraphPattern.NON_PATTERN_NODE_TYPE in node_2[attr]:
27+
continue
28+
# Torch and TF pattern mapping based on 'type' section,
29+
# While ONNX mapping based on metatypes -
30+
# to support all of them, we need to check the existane of the attributes
31+
if GraphPattern.NODE_TYPE_ATTR in node_1:
32+
if node_1[GraphPattern.NODE_TYPE_ATTR] in node_2[attr]:
33+
continue
34+
if node_1[attr] not in node_2[attr]:
35+
return False
36+
return True
37+
38+
39+
def _sort_patterns_by_len(pattern: nx.DiGraph) -> int:
2340
"""
24-
Checks out whether the 'pattern_subgraph' has outgoing edges,
25-
that aren't connected with nodes from full_subgraph_with_non_pattern_nodes.
41+
Sort patterns by their length. GraphPattern.NON_PATTERN_NODE_TYPE is not counted as a pattern node.
42+
"""
43+
non_pattern_nodes = [
44+
node_id
45+
for node_id, node_data in pattern.nodes(data=True)
46+
if GraphPattern.NON_PATTERN_NODE_TYPE in node_data[GraphPattern.METATYPE_ATTR]
47+
]
48+
return len(pattern) - len(non_pattern_nodes)
49+
50+
51+
def _is_subgraph_matching_strict(graph: nx.DiGraph, pattern: nx.DiGraph, subgraph: Dict[str, str]) -> bool:
52+
"""
53+
Checks out whether the matched subgraph has:
54+
1) External predecessors of starting nodes.
55+
2) External successors of the last nodes.
56+
3) External successors or predecessors of the nodes which are not starting and last.
57+
If any of these conditions is True, than returns False, otherwise - True.
58+
The checks are skipped for NON_PATTERN_NODE_TYPE.
2659
Example:
60+
This subgraph matching is not strict.
2761
(conv2d + BN + ReLU pattern):
2862
...
2963
|
@@ -37,119 +71,81 @@ def is_subgraph_has_inner_outgoing_edges(
3771
|
3872
...
3973
:param graph: The model graph.
40-
:param full_subgraph_with_non_pattern_nodes: A subgraph of the model graph including the nodes outside the pattern.
41-
:param pattern_subgraph: A subgraph of the model.
42-
:return: True if the subgraph contains outgoing edges starting not from the last node,
43-
False - otherwise.
74+
:param pattern: The matched pattern.
75+
:param subgraph: A subgraph of the model graph including the nodes outside the pattern.
76+
:return: If any of three conditions is True than returns False, otherwise - True.
77+
"""
78+
starting_nodes = []
79+
last_nodes = []
80+
for node in pattern.nodes:
81+
if not pattern.pred[node] and pattern.succ[node]:
82+
starting_nodes.append(node)
83+
if pattern.pred[node] and not pattern.succ[node]:
84+
last_nodes.append(node)
85+
86+
for node_from_graph, node_from_pattern in subgraph.items():
87+
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern.nodes[node_from_pattern].get(GraphPattern.METATYPE_ATTR):
88+
continue
89+
predecessors_keys = graph.pred[node_from_graph].keys()
90+
successor_keys = graph.succ[node_from_graph].keys()
91+
has_external_successors = any(successor_key not in subgraph for successor_key in successor_keys)
92+
has_external_predcessors = any(predecessor_key not in subgraph for predecessor_key in predecessors_keys)
93+
if node_from_pattern in starting_nodes and has_external_successors:
94+
return False
95+
if node_from_pattern in last_nodes and has_external_predcessors:
96+
return False
97+
if (node_from_pattern not in last_nodes and node_from_pattern not in starting_nodes) and (
98+
has_external_successors or has_external_predcessors
99+
):
100+
return False
101+
return True
102+
103+
104+
def _copy_subgraph_excluding_non_pattern_node(subgraph: Dict[str, str], pattern_graph: GraphPattern) -> Dict[str, str]:
44105
"""
45-
first_node = pattern_subgraph[0]
46-
last_node = pattern_subgraph[-1]
47-
for node_key in pattern_subgraph:
48-
if node_key == last_node:
49-
predecessors = list(graph.pred[node_key].keys())
50-
if any(predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors):
51-
return True
52-
elif node_key == first_node:
53-
successors = list(graph.succ[node_key].keys())
54-
if any(successor not in full_subgraph_with_non_pattern_nodes for successor in successors):
55-
return True
56-
else:
57-
successors = list(graph.succ[node_key].keys())
58-
predecessors = list(graph.pred[node_key].keys())
59-
if any(successors_key not in full_subgraph_with_non_pattern_nodes for successors_key in successors):
60-
return True
61-
if any(predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors):
62-
return True
63-
return False
106+
Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE.
107+
108+
:param subgraph: Subgraph
109+
:param pattern_graph: A graph consists of patterns to match.
110+
:return: New subgraph without excluded nodes.
111+
"""
112+
output = {}
113+
for node_from_graph, node_from_pattern in subgraph.items():
114+
pattern_node = pattern_graph.graph.nodes[node_from_pattern]
115+
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR)
116+
if GraphPattern.NON_PATTERN_NODE_TYPE not in pattern_node_types:
117+
output[node_from_graph] = node_from_pattern
118+
return output
64119

65120

66121
def find_subgraphs_matching_pattern(graph: nx.DiGraph, pattern_graph: GraphPattern) -> List[List[str]]:
67122
"""
68-
Find a list of subgraphs for the particular graph that match the pattern expression.
123+
Finds a list of nodes which define a subgraph matched a pattern in pattern_graph.
124+
Nodes in each subgraph is stored in lexicographical_topological_sort.
125+
69126
:param graph: The model graph.
70-
:param pattern_graph: A graph consists of patterns for layer fusing logic.
71-
:return: A list of subgraphs, matching the pattern expression.
72-
Each subgraph is defined as a list of node keys.
127+
:param pattern_graph: A graph consists of patterns to match.
128+
:return: A list of subgraphs are mathced to the patterns. Each subgraph is defined as a list of node keys.
73129
"""
74-
75-
def are_nodes_matching(node_1, node_2):
76-
for attr in node_2:
77-
if attr == GraphPattern.LABEL_ATTR:
78-
continue
79-
if attr == GraphPattern.METATYPE_ATTR:
80-
# GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
81-
# are matched to any node type.
82-
83-
if (
84-
GraphPattern.ANY_PATTERN_NODE_TYPE in node_2[attr]
85-
or GraphPattern.NON_PATTERN_NODE_TYPE in node_2[attr]
86-
):
87-
continue
88-
# Torch and TF pattern mapping based on 'type' section,
89-
# While ONNX mapping based on metatypes -
90-
# to support all of them, we need to check the existane of the attributes
91-
if GraphPattern.NODE_TYPE_ATTR in node_1:
92-
if node_1[GraphPattern.NODE_TYPE_ATTR] in node_2[attr]:
93-
continue
94-
if node_1[attr] not in node_2[attr]:
95-
return False
96-
return True
97-
98-
def are_edges_matching(edge_1, edge_2):
99-
for attr in edge_2:
100-
if edge_1[attr] not in edge_2[attr]:
101-
return False
102-
return True
103-
104-
subgraphs = [] # type: List[List[str]]
105-
visited_nodes = set() # type: Set[str]
106-
patterns = [] # type: List[nx.DiGraph]
107-
for c in nx.weakly_connected_components(pattern_graph.graph):
108-
patterns.append(pattern_graph.graph.subgraph(c))
109-
110-
def sort_patterns(pattern: nx.DiGraph):
111-
"""
112-
Sort patterns by their length,
113-
keeping in mind that if node type is GraphPattern.NON_PATTERN_NODE_TYPE it shouldn't count.
114-
"""
115-
pattern_len = len(pattern)
116-
for node in pattern.nodes:
117-
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_graph.graph.nodes.get(node)[GraphPattern.METATYPE_ATTR]:
118-
pattern_len -= 1
119-
return pattern_len
120-
121-
# Get all patterns sorted by their lengths
122-
# as we want match the longest patterns first
123-
124-
patterns = sorted(patterns, key=sort_patterns, reverse=True)
125-
130+
subgraphs = []
131+
matched_nodes = set()
132+
patterns = pattern_graph.get_weakly_connected_subgraphs()
133+
patterns = sorted(patterns, key=_sort_patterns_by_len, reverse=True)
126134
for pattern in patterns:
127-
matcher = ism.DiGraphMatcher(graph, pattern, node_match=are_nodes_matching, edge_match=are_edges_matching)
135+
matcher = ism.DiGraphMatcher(graph, pattern, node_match=_are_nodes_matched)
128136
for subgraph in matcher.subgraph_isomorphisms_iter():
129-
# Bottleneck that need to sort by id for result consistency
130-
pattern_subgraph = list(
131-
nx.lexicographical_topological_sort(graph.subgraph(subgraph), key=lambda x: int(x.split()[0]))
132-
)
133-
134-
full_subgraph_with_non_pattern_nodes = pattern_subgraph[:]
135-
outside_pattern_nodes = []
136-
137-
# If some nodes are outside the pattern - remove them from pattern_subgraph
138-
139-
for node, pattern_node_id in matcher.mapping.items():
140-
pattern_node = pattern_graph.graph.nodes[pattern_node_id]
141-
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR)
142-
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_node_types:
143-
outside_pattern_nodes.append(node)
144-
for node in outside_pattern_nodes:
145-
pattern_subgraph.remove(node)
146-
147-
is_visited_node = any(node in visited_nodes for node in pattern_subgraph)
148-
if is_visited_node:
137+
is_matching_strict = _is_subgraph_matching_strict(graph, pattern, subgraph)
138+
if not is_matching_strict:
149139
continue
150-
if is_subgraph_has_inner_outgoing_edges(graph, full_subgraph_with_non_pattern_nodes, pattern_subgraph):
140+
141+
subgraph = _copy_subgraph_excluding_non_pattern_node(subgraph, pattern_graph)
142+
is_any_node_matched = any(node in matched_nodes for node in subgraph)
143+
144+
if is_any_node_matched:
151145
continue
152-
visited_nodes.update(pattern_subgraph)
153-
subgraphs.append(pattern_subgraph)
154146

155-
return subgraphs if subgraphs else []
147+
matched_nodes.update(subgraph)
148+
sorted_nodes_subgraph = list(nx.lexicographical_topological_sort(graph.subgraph(subgraph)))
149+
subgraphs.append(sorted_nodes_subgraph)
150+
151+
return subgraphs

nncf/common/graph/patterns/patterns.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -380,22 +380,10 @@ class HWFusedPatternNames(Enum):
380380
"linear_biased_activation_elementwise", devices=[TargetDevice.ANY, TargetDevice.CPU, TargetDevice.GPU]
381381
)
382382

383-
# TRANSFORMERS
384-
MATMUL_SOFTMAX_MATMUL = PatternDesc("matmul_softmax_matmul", model_types=[ModelType.TRANSFORMER])
385-
SOFTMAX_RESHAPE_MATMUL = PatternDesc("softmax_reshape_matmul", model_types=[ModelType.TRANSFORMER])
386-
SOFTMAX_RESHAPE_TRANSPOSE_GATHER_MATMUL = PatternDesc(
387-
"softmax_reshape_transpose_gather_matmul", model_types=[ModelType.TRANSFORMER]
388-
)
389-
SOFTMAX_RESHAPE_TRANSPOSE_MATMUL = PatternDesc(
390-
"softmax_reshape_transpose_matmul", model_types=[ModelType.TRANSFORMER]
391-
)
392-
STABLE_DIFFUSION = PatternDesc("stable_diffusion", model_types=[ModelType.TRANSFORMER])
393-
394383

395384
class IgnoredPatternNames(Enum):
396385
"""
397386
Describes the patterns, which nodes should be ignored during FakeQuantize placement.
398387
"""
399388

400-
SOFTMAX_MATMUL = PatternDesc("softmax_matmul", model_types=[ModelType.TRANSFORMER])
401-
SOFTMAX_RESHAPE_MATMUL = PatternDesc("softmax_reshape_matmul", model_types=[ModelType.TRANSFORMER])
389+
MULTIHEAD_ATTENTION_OUTPUT = PatternDesc("multihead_attention_output", model_types=[ModelType.TRANSFORMER])

nncf/onnx/hardware/fused_patterns.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,50 +82,6 @@ def create_swish_with_hard_sigmoid() -> GraphPattern:
8282
return pattern
8383

8484

85-
@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.MATMUL_SOFTMAX_MATMUL)
86-
def create_matmul_softmax_matmul() -> GraphPattern:
87-
pattern = GraphPattern()
88-
softmax_1 = pattern.add_node(
89-
**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.ONNXSoftmaxMetatype}
90-
)
91-
mat_mul_1_1 = pattern.add_node(
92-
**{GraphPattern.LABEL_ATTR: "MATMUL_1", GraphPattern.METATYPE_ATTR: om.ONNXLinearMetatype}
93-
)
94-
mat_mul_2_1 = pattern.add_node(
95-
**{GraphPattern.LABEL_ATTR: "MATMUL_2", GraphPattern.METATYPE_ATTR: om.ONNXLinearMetatype}
96-
)
97-
98-
any_1 = pattern.add_node(
99-
**{GraphPattern.LABEL_ATTR: "ANY", GraphPattern.METATYPE_ATTR: GraphPattern.NON_PATTERN_NODE_TYPE}
100-
)
101-
102-
pattern.add_edge(mat_mul_1_1, softmax_1)
103-
pattern.add_edge(softmax_1, mat_mul_2_1)
104-
pattern.add_edge(any_1, mat_mul_2_1)
105-
106-
softmax_2 = pattern.add_node(
107-
**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.ONNXSoftmaxMetatype}
108-
)
109-
add_2 = pattern.add_node(**{GraphPattern.LABEL_ATTR: "ADD", GraphPattern.METATYPE_ATTR: om.ONNXAddLayerMetatype})
110-
mat_mul_1_2 = pattern.add_node(
111-
**{GraphPattern.LABEL_ATTR: "MATMUL_1", GraphPattern.METATYPE_ATTR: om.ONNXLinearMetatype}
112-
)
113-
mat_mul_2_2 = pattern.add_node(
114-
**{GraphPattern.LABEL_ATTR: "MATMUL_2", GraphPattern.METATYPE_ATTR: om.ONNXLinearMetatype}
115-
)
116-
117-
any_2 = pattern.add_node(
118-
**{GraphPattern.LABEL_ATTR: "ANY", GraphPattern.METATYPE_ATTR: GraphPattern.NON_PATTERN_NODE_TYPE}
119-
)
120-
121-
pattern.add_edge(mat_mul_1_2, add_2)
122-
pattern.add_edge(add_2, softmax_2)
123-
pattern.add_edge(softmax_2, mat_mul_2_2)
124-
pattern.add_edge(any_2, mat_mul_2_2)
125-
126-
return pattern
127-
128-
12985
# INPUT PROCESSING
13086

13187

0 commit comments

Comments
 (0)