|
19 | 19 | from __future__ import print_function
|
20 | 20 | from __future__ import unicode_literals
|
21 | 21 |
|
22 |
| -import copy |
23 |
| - |
| 22 | +from itertools import permutations |
24 | 23 | import six
|
25 | 24 |
|
26 | 25 |
|
27 |
| - |
28 | 26 | class OpTypePattern(object):
|
29 | 27 | """A tree pattern that matches TF expressions with certain op types."""
|
30 | 28 |
|
@@ -188,29 +186,24 @@ def _match_pattern(self, pattern, op, tensor):
|
188 | 186 | return False, match_list
|
189 | 187 |
|
190 | 188 | if self._allow_reorder:
|
191 |
| - inputs = [None] * len(op.inputs) |
192 |
| - wanted = copy.copy(pattern.inputs) |
193 |
| - for idx, i in enumerate(op.inputs): |
194 |
| - for j in range(len(wanted)): # pylint: disable=consider-using-enumerate |
195 |
| - if i.type == wanted[j].op_type: |
196 |
| - inputs[idx] = wanted[j] |
197 |
| - del wanted[j] |
198 |
| - break |
199 |
| - for idx, i in enumerate(inputs): |
200 |
| - if i is None: |
201 |
| - inputs[idx] = wanted[0] |
202 |
| - del wanted[0] |
203 |
| - pat = list(zip(op.inputs, inputs)) |
| 189 | + pattern_inputs_list = permutations(pattern.inputs) |
204 | 190 | else:
|
205 |
| - pat = list(zip(op.inputs, pattern.inputs)) |
206 |
| - |
207 |
| - match_flag_of_inputs = [] |
208 |
| - for input_tensor, input_pattern in pat: |
209 |
| - # print("MATCHING", input_pattern.op_type, input_tensor.type) |
210 |
| - flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor) |
211 |
| - match_flag_of_inputs.append(flag) |
212 |
| - match_list.extend(match_list_of_input) |
213 |
| - return all(match_flag_of_inputs), match_list |
| 191 | + pattern_inputs_list = [pattern.inputs] |
| 192 | + |
| 193 | + for possible_pattern_inputs in pattern_inputs_list: |
| 194 | + pat = list(zip(op.inputs, possible_pattern_inputs)) |
| 195 | + match_flag_of_inputs = [] |
| 196 | + match_lists_of_inputs = [] |
| 197 | + for input_tensor, input_pattern in pat: |
| 198 | + # print("MATCHING", input_pattern.op_type, input_tensor.type) |
| 199 | + flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor) |
| 200 | + match_flag_of_inputs.append(flag) |
| 201 | + match_lists_of_inputs.extend(match_list_of_input) |
| 202 | + |
| 203 | + if all(match_flag_of_inputs): |
| 204 | + match_list.extend(match_lists_of_inputs) |
| 205 | + return True, match_list |
| 206 | + return False, match_list |
214 | 207 |
|
215 | 208 | def _parse_match_list_to_match_result(self, match_list):
|
216 | 209 | for pattern, op, tensor in match_list:
|
|
0 commit comments