Skip to content

Commit d441206

Browse files
committed
fix bug in allow_reorder
1 parent 8c0ac9e commit d441206

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

tf2onnx/graph_matcher.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
from __future__ import print_function
2020
from __future__ import unicode_literals
2121

22-
import copy
23-
22+
from itertools import permutations
2423
import six
2524

2625

27-
2826
class OpTypePattern(object):
2927
"""A tree pattern that matches TF expressions with certain op types."""
3028

@@ -188,29 +186,24 @@ def _match_pattern(self, pattern, op, tensor):
188186
return False, match_list
189187

190188
if self._allow_reorder:
191-
inputs = [None] * len(op.inputs)
192-
wanted = copy.copy(pattern.inputs)
193-
for idx, i in enumerate(op.inputs):
194-
for j in range(len(wanted)): # pylint: disable=consider-using-enumerate
195-
if i.type == wanted[j].op_type:
196-
inputs[idx] = wanted[j]
197-
del wanted[j]
198-
break
199-
for idx, i in enumerate(inputs):
200-
if i is None:
201-
inputs[idx] = wanted[0]
202-
del wanted[0]
203-
pat = list(zip(op.inputs, inputs))
189+
pattern_inputs_list = permutations(pattern.inputs)
204190
else:
205-
pat = list(zip(op.inputs, pattern.inputs))
206-
207-
match_flag_of_inputs = []
208-
for input_tensor, input_pattern in pat:
209-
# print("MATCHING", input_pattern.op_type, input_tensor.type)
210-
flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor)
211-
match_flag_of_inputs.append(flag)
212-
match_list.extend(match_list_of_input)
213-
return all(match_flag_of_inputs), match_list
191+
pattern_inputs_list = [pattern.inputs]
192+
193+
for possible_pattern_inputs in pattern_inputs_list:
194+
pat = list(zip(op.inputs, possible_pattern_inputs))
195+
match_flag_of_inputs = []
196+
match_lists_of_inputs = []
197+
for input_tensor, input_pattern in pat:
198+
# print("MATCHING", input_pattern.op_type, input_tensor.type)
199+
flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor)
200+
match_flag_of_inputs.append(flag)
201+
match_lists_of_inputs.extend(match_list_of_input)
202+
203+
if all(match_flag_of_inputs):
204+
match_list.extend(match_lists_of_inputs)
205+
return True, match_list
206+
return False, match_list
214207

215208
def _parse_match_list_to_match_result(self, match_list):
216209
for pattern, op, tensor in match_list:

0 commit comments

Comments
 (0)