Skip to content

Commit 8c0ac9e

Browse files
committed
refactor:
recursive function "_match_pattern" not interact with sub call to _match_pattern by object's data member implictly
1 parent 6e9aa62 commit 8c0ac9e

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

tf2onnx/graph_matcher.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def _match_pattern(self, pattern, op, tensor):
161161
pattern tree.
162162
163163
Returns:
164-
True if an TF expression rooted at `op` matches `pattern`.
164+
if matched return True and match_list whose elem is [pattern, op, tensor]
165+
else return False
165166
the condition that op is matched with pattern:
166167
1 op is same:
167168
if pattern.op_type is None or *, then treat as same
@@ -170,21 +171,21 @@ def _match_pattern(self, pattern, op, tensor):
170171
if not pattern.inputs, then treat as same
171172
otherwise, iteratively compare input nodes with pattern.
172173
"""
173-
174+
match_list = []
174175
if pattern.op_type is None:
175-
return True
176+
return True, match_list
176177

177178
if self._is_op_type_same(op, pattern):
178-
self._match_result.add(pattern, op, tensor)
179+
match_list.append([pattern, op, tensor])
179180
else:
180-
return False
181+
return False, match_list
181182

182183
if not pattern.inputs:
183184
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
184-
return True
185+
return True, match_list
185186

186187
if not op or len(op.inputs) != len(pattern.inputs):
187-
return False
188+
return False, match_list
188189

189190
if self._allow_reorder:
190191
inputs = [None] * len(op.inputs)
@@ -203,12 +204,17 @@ def _match_pattern(self, pattern, op, tensor):
203204
else:
204205
pat = list(zip(op.inputs, pattern.inputs))
205206

206-
ret = []
207+
match_flag_of_inputs = []
207208
for input_tensor, input_pattern in pat:
208209
# print("MATCHING", input_pattern.op_type, input_tensor.type)
209-
r = self._match_pattern(input_pattern, input_tensor, input_tensor)
210-
ret.append(r)
211-
return all(ret)
210+
flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor)
211+
match_flag_of_inputs.append(flag)
212+
match_list.extend(match_list_of_input)
213+
return all(match_flag_of_inputs), match_list
214+
215+
def _parse_match_list_to_match_result(self, match_list):
216+
for pattern, op, tensor in match_list:
217+
self._match_result.add(pattern, op, tensor)
212218

213219
def match_op(self, op):
214220
"""Matches `op` against `self._pattern`.
@@ -221,8 +227,10 @@ def match_op(self, op):
221227
None.
222228
"""
223229
self._match_result = MatchResult()
224-
if not self._match_pattern(self._pattern, op, tensor=None):
230+
match_flag, match_list = self._match_pattern(self._pattern, op, tensor=None)
231+
if not match_flag:
225232
return None
233+
self._parse_match_list_to_match_result(match_list)
226234
return self._match_result
227235

228236
def match_ops(self, ops):

0 commit comments

Comments
 (0)