@@ -161,7 +161,8 @@ def _match_pattern(self, pattern, op, tensor):
161
161
pattern tree.
162
162
163
163
Returns:
164
- True if an TF expression rooted at `op` matches `pattern`.
164
+ if matched return True and match_list whose elem is [pattern, op, tensor]
165
+ else return False
165
166
the condition that op is matched with pattern:
166
167
1 op is same:
167
168
if pattern.op_type is None or *, then treat as same
@@ -170,21 +171,21 @@ def _match_pattern(self, pattern, op, tensor):
170
171
if not pattern.inputs, then treat as same
171
172
otherwise, iteratively compare input nodes with pattern.
172
173
"""
173
-
174
+ match_list = []
174
175
if pattern .op_type is None :
175
- return True
176
+ return True , match_list
176
177
177
178
if self ._is_op_type_same (op , pattern ):
178
- self . _match_result . add ( pattern , op , tensor )
179
+ match_list . append ([ pattern , op , tensor ] )
179
180
else :
180
- return False
181
+ return False , match_list
181
182
182
183
if not pattern .inputs :
183
184
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
184
- return True
185
+ return True , match_list
185
186
186
187
if not op or len (op .inputs ) != len (pattern .inputs ):
187
- return False
188
+ return False , match_list
188
189
189
190
if self ._allow_reorder :
190
191
inputs = [None ] * len (op .inputs )
@@ -203,12 +204,17 @@ def _match_pattern(self, pattern, op, tensor):
203
204
else :
204
205
pat = list (zip (op .inputs , pattern .inputs ))
205
206
206
- ret = []
207
+ match_flag_of_inputs = []
207
208
for input_tensor , input_pattern in pat :
208
209
# print("MATCHING", input_pattern.op_type, input_tensor.type)
209
- r = self ._match_pattern (input_pattern , input_tensor , input_tensor )
210
- ret .append (r )
211
- return all (ret )
210
+ flag , match_list_of_input = self ._match_pattern (input_pattern , input_tensor , input_tensor )
211
+ match_flag_of_inputs .append (flag )
212
+ match_list .extend (match_list_of_input )
213
+ return all (match_flag_of_inputs ), match_list
214
+
215
+ def _parse_match_list_to_match_result (self , match_list ):
216
+ for pattern , op , tensor in match_list :
217
+ self ._match_result .add (pattern , op , tensor )
212
218
213
219
def match_op (self , op ):
214
220
"""Matches `op` against `self._pattern`.
@@ -221,8 +227,10 @@ def match_op(self, op):
221
227
None.
222
228
"""
223
229
self ._match_result = MatchResult ()
224
- if not self ._match_pattern (self ._pattern , op , tensor = None ):
230
+ match_flag , match_list = self ._match_pattern (self ._pattern , op , tensor = None )
231
+ if not match_flag :
225
232
return None
233
+ self ._parse_match_list_to_match_result (match_list )
226
234
return self ._match_result
227
235
228
236
def match_ops (self , ops ):
0 commit comments