Skip to content

Commit 6ce2eb9

Browse files
committed
merge transpose nodes which share same input and perm attr are also same
current strategy is to move down transpose nodes if possible while this is is not always a good idea for it may generate more transpose nodes which not removed in later procedures. so it's better to merge them back into one after removing procedures.
1 parent d949ceb commit 6ce2eb9

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Transpose Optimizer."""
44

55
from __future__ import unicode_literals
6+
from collections import defaultdict
67

78
import logging
89

@@ -113,6 +114,27 @@ def post_optimize_action(self):
113114
self._g.update_proto()
114115
self._g.topological_sort(self._g.get_nodes())
115116

117+
def merge_transpose_with_same_input(self):
118+
# strategy used in previous procedure is to move transpose nodes down if possible,
119+
# and it means that when a node has n outputs then n transpose will be generated ,
120+
# so we should merge them back to one if they can't be eliminated in previous procedure.
121+
graph = self._g
122+
input_transposes_map = defaultdict(list)
123+
for node in graph.get_nodes():
124+
if node.type == "Transpose":
125+
key = (node.input[0], str(node.get_attr("perm").ints))
126+
input_transposes_map[key].append(node)
127+
128+
for _, transposes in input_transposes_map.items():
129+
# merge transpose nodes into one: make nodes use the output of the first transpose node
130+
transpose_out = transposes[0].output[0]
131+
for node in transposes[1:]:
132+
old_transpose_out = node.output[0]
133+
graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out)
134+
135+
# dangling transpose nodes can be deleted
136+
graph.delete_unused_nodes(graph.outputs)
137+
116138
def optimize(self):
117139
previous_counter = self._g.dump_node_statistics()
118140
no_action = False
@@ -140,7 +162,9 @@ def optimize(self):
140162
break
141163

142164
log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
165+
143166
self.post_optimize_action()
167+
self.merge_transpose_with_same_input()
144168

145169
current_counter = self._g.dump_node_statistics()
146170
transpose_cnt = current_counter["Transpose"]

0 commit comments

Comments
 (0)