|
3 | 3 | """Transpose Optimizer."""
|
4 | 4 |
|
5 | 5 | from __future__ import unicode_literals
|
| 6 | +from collections import defaultdict |
6 | 7 |
|
7 | 8 | import logging
|
8 | 9 |
|
@@ -113,6 +114,27 @@ def post_optimize_action(self):
|
113 | 114 | self._g.update_proto()
|
114 | 115 | self._g.topological_sort(self._g.get_nodes())
|
115 | 116 |
|
| 117 | + def merge_transpose_with_same_input(self): |
| 118 | + # strategy used in previous procedure is to move transpose nodes down if possible, |
| 119 | + # and it means that when a node has n outputs then n transpose will be generated , |
| 120 | + # so we should merge them back to one if they can't be eliminated in previous procedure. |
| 121 | + graph = self._g |
| 122 | + input_transposes_map = defaultdict(list) |
| 123 | + for node in graph.get_nodes(): |
| 124 | + if node.type == "Transpose": |
| 125 | + key = (node.input[0], str(node.get_attr("perm").ints)) |
| 126 | + input_transposes_map[key].append(node) |
| 127 | + |
| 128 | + for _, transposes in input_transposes_map.items(): |
| 129 | + # merge transpose nodes into one: make nodes use the output of the first transpose node |
| 130 | + transpose_out = transposes[0].output[0] |
| 131 | + for node in transposes[1:]: |
| 132 | + old_transpose_out = node.output[0] |
| 133 | + graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out) |
| 134 | + |
| 135 | + # dangling transpose nodes can be deleted |
| 136 | + graph.delete_unused_nodes(graph.outputs) |
| 137 | + |
116 | 138 | def optimize(self):
|
117 | 139 | previous_counter = self._g.dump_node_statistics()
|
118 | 140 | no_action = False
|
@@ -140,7 +162,9 @@ def optimize(self):
|
140 | 162 | break
|
141 | 163 |
|
142 | 164 | log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
|
| 165 | + |
143 | 166 | self.post_optimize_action()
|
| 167 | + self.merge_transpose_with_same_input() |
144 | 168 |
|
145 | 169 | current_counter = self._g.dump_node_statistics()
|
146 | 170 | transpose_cnt = current_counter["Transpose"]
|
|
0 commit comments