Skip to content

Commit 6e9aa62

Browse files
committed
refactor function _match_pattern
1 parent cbb3538 commit 6e9aa62

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

tf2onnx/graph_matcher.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ def __init__(self, pattern, allow_reorder=False):
137137
self._pattern = pattern
138138
self._allow_reorder = allow_reorder
139139

140+
@staticmethod
141+
def _is_op_type_same(op, pattern):
142+
if pattern.op_type == "*":
143+
return True
144+
145+
if op.type in pattern.op_type.split('|'):
146+
return True
147+
148+
return False
149+
140150
def _match_pattern(self, pattern, op, tensor):
141151
"""Returns whether an TF expression rooted at `op` matches `pattern`.
142152
@@ -152,16 +162,22 @@ def _match_pattern(self, pattern, op, tensor):
152162
153163
Returns:
154164
True if an TF expression rooted at `op` matches `pattern`.
165+
the condition that op is matched with pattern:
166+
1 op is same:
167+
if pattern.op_type is None or *, then treat as same
168+
or op.type in pattern.op_type.split("|")
169+
2 op.inputs are same with pattern.inputs:
170+
if not pattern.inputs, then treat as same
171+
otherwise, iteratively compare input nodes with pattern.
155172
"""
173+
156174
if pattern.op_type is None:
157175
return True
158176

159-
if pattern.op_type != '*':
160-
if op is None or op.type not in pattern.op_type.split('|'):
161-
return False
162-
163-
self._match_result.add(pattern, op, tensor)
164-
# print("matched", ",".join([op.type + "|" + op.name for op in self._match_result.get_nodes()]))
177+
if self._is_op_type_same(op, pattern):
178+
self._match_result.add(pattern, op, tensor)
179+
else:
180+
return False
165181

166182
if not pattern.inputs:
167183
# If pattern.inputs is empty, skips the rest and accepts all the inputs.

0 commit comments

Comments
 (0)