|
1 | 1 | # Copyright (c) Microsoft Corporation. All rights reserved.
|
2 | 2 | # Licensed under the MIT license.
|
| 3 | + |
3 | 4 | """Transpose Optimizer."""
|
4 | 5 |
|
5 | 6 | from __future__ import unicode_literals
|
| 7 | +from collections import defaultdict |
6 | 8 |
|
7 | 9 | import logging
|
8 | 10 |
|
@@ -113,6 +115,27 @@ def post_optimize_action(self):
|
113 | 115 | self._g.update_proto()
|
114 | 116 | self._g.topological_sort(self._g.get_nodes())
|
115 | 117 |
|
def merge_duplicated_transposes(self):
    """Merge Transpose nodes that read the same input with the same perm.

    The earlier optimization steps push transposes downwards where possible,
    so a node with n consumers can end up with n identical Transpose nodes.
    Any duplicates that were not eliminated are collapsed here: every
    consumer is rewired to the output of the first Transpose in its group,
    and the now-dangling duplicates are removed by the unused-node sweep.

    Transpose nodes without an explicit "perm" attribute are left untouched
    instead of raising when the attribute is dereferenced.
    """
    graph = self._g

    # group Transpose nodes by (input tensor name, permutation)
    input_transposes_map = defaultdict(list)
    for node in graph.get_nodes():
        if node.type != "Transpose":
            continue
        perm_attr = node.get_attr("perm")
        if not perm_attr:
            # "perm" is optional; get_attr gives nothing to read .ints from
            continue
        key = (node.input[0], tuple(perm_attr.ints))
        input_transposes_map[key].append(node)

    for transposes in input_transposes_map.values():
        # merge transpose nodes into one: make consumers use the output of
        # the first transpose node of the group
        transpose_out = transposes[0].output[0]
        for duplicate in transposes[1:]:
            old_transpose_out = duplicate.output[0]
            graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out)

    # dangling transpose nodes can be deleted
    graph.delete_unused_nodes(graph.outputs)
| 138 | + |
116 | 139 | def optimize(self):
|
117 | 140 | previous_counter = self._g.dump_node_statistics()
|
118 | 141 | no_action = False
|
@@ -140,6 +163,8 @@ def optimize(self):
|
140 | 163 | break
|
141 | 164 |
|
142 | 165 | log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
|
| 166 | + |
| 167 | + self.merge_duplicated_transposes() |
143 | 168 | self.post_optimize_action()
|
144 | 169 |
|
145 | 170 | current_counter = self._g.dump_node_statistics()
|
|
0 commit comments