@@ -137,6 +137,16 @@ def __init__(self, pattern, allow_reorder=False):
137
137
self ._pattern = pattern
138
138
self ._allow_reorder = allow_reorder
139
139
140
+ @staticmethod
141
+ def _is_op_type_same (op , pattern ):
142
+ if pattern .op_type == "*" :
143
+ return True
144
+
145
+ if op .type in pattern .op_type .split ('|' ):
146
+ return True
147
+
148
+ return False
149
+
140
150
def _match_pattern (self , pattern , op , tensor ):
141
151
"""Returns whether an TF expression rooted at `op` matches `pattern`.
142
152
@@ -152,16 +162,22 @@ def _match_pattern(self, pattern, op, tensor):
152
162
153
163
Returns:
154
164
True if an TF expression rooted at `op` matches `pattern`.
165
+ the condition that op is matched with pattern:
166
+ 1 op is same:
167
+ if pattern.op_type is None or *, then treat as same
168
+ or op.type in pattern.op_type.split("|")
169
+ 2 op.inputs are same with pattern.inputs:
170
+ if not pattern.inputs, then treat as same
171
+ otherwise, iteratively compare input nodes with pattern.
155
172
"""
173
+
156
174
if pattern .op_type is None :
157
175
return True
158
176
159
- if pattern .op_type != '*' :
160
- if op is None or op .type not in pattern .op_type .split ('|' ):
161
- return False
162
-
163
- self ._match_result .add (pattern , op , tensor )
164
- # print("matched", ",".join([op.type + "|" + op.name for op in self._match_result.get_nodes()]))
177
+ if self ._is_op_type_same (op , pattern ):
178
+ self ._match_result .add (pattern , op , tensor )
179
+ else :
180
+ return False
165
181
166
182
if not pattern .inputs :
167
183
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
0 commit comments