Skip to content

Commit 5810313

Browse files
authored
Merge pull request #326 from zhijxu-MS/push_branch
merge transpose nodes which share same input and perm attr are also same
2 parents d949ceb + 603da0a commit 5810313

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
3+
34
"""Transpose Optimizer."""
45

56
from __future__ import unicode_literals
7+
from collections import defaultdict
68

79
import logging
810

@@ -113,6 +115,27 @@ def post_optimize_action(self):
113115
self._g.update_proto()
114116
self._g.topological_sort(self._g.get_nodes())
115117

118+
def merge_duplicated_transposes(self):
119+
# strategy used in previous procedure is to move transpose nodes down if possible,
120+
# and it means that when a node has n outputs then n transpose will be generated,
121+
# so we should merge them back to one if they can't be eliminated in previous procedure.
122+
graph = self._g
123+
input_transposes_map = defaultdict(list)
124+
for node in graph.get_nodes():
125+
if node.type == "Transpose":
126+
key = (node.input[0], str(node.get_attr("perm").ints))
127+
input_transposes_map[key].append(node)
128+
129+
for transposes in input_transposes_map.values():
130+
# merge transpose nodes into one: make nodes use the output of the first transpose node
131+
transpose_out = transposes[0].output[0]
132+
for node in transposes[1:]:
133+
old_transpose_out = node.output[0]
134+
graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out)
135+
136+
# dangling transpose nodes can be deleted
137+
graph.delete_unused_nodes(graph.outputs)
138+
116139
def optimize(self):
117140
previous_counter = self._g.dump_node_statistics()
118141
no_action = False
@@ -140,6 +163,8 @@ def optimize(self):
140163
break
141164

142165
log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
166+
167+
self.merge_duplicated_transposes()
143168
self.post_optimize_action()
144169

145170
current_counter = self._g.dump_node_statistics()

0 commit comments

Comments
 (0)