1
1
# Copyright (c) Microsoft Corporation. All rights reserved.
2
2
# Licensed under the MIT license.
3
+
3
4
"""Transpose Optimizer."""
4
5
5
6
from __future__ import unicode_literals
@@ -114,9 +115,9 @@ def post_optimize_action(self):
114
115
self ._g .update_proto ()
115
116
self ._g .topological_sort (self ._g .get_nodes ())
116
117
117
- def merge_transpose_with_same_input (self ):
118
+ def merge_duplicated_transposes (self ):
118
119
# the strategy used in the previous procedure is to move transpose nodes down if possible,
119
- # and it means that when a node has n outputs then n transpose will be generated ,
120
+ # which means that when a node has n outputs, n transposes will be generated,
120
121
# so we should merge them back into one if they couldn't be eliminated in the previous procedure.
121
122
graph = self ._g
122
123
input_transposes_map = defaultdict (list )
@@ -125,7 +126,7 @@ def merge_transpose_with_same_input(self):
125
126
key = (node .input [0 ], str (node .get_attr ("perm" ).ints ))
126
127
input_transposes_map [key ].append (node )
127
128
128
- for _ , transposes in input_transposes_map .items ():
129
+ for transposes in input_transposes_map .values ():
129
130
# merge transpose nodes into one: make nodes use the output of the first transpose node
130
131
transpose_out = transposes [0 ].output [0 ]
131
132
for node in transposes [1 :]:
@@ -163,8 +164,8 @@ def optimize(self):
163
164
164
165
log .debug ("finish after " + str (iteration_cnt ) + " iteration(s)" )
165
166
167
+ self .merge_duplicated_transposes ()
166
168
self .post_optimize_action ()
167
- self .merge_transpose_with_same_input ()
168
169
169
170
current_counter = self ._g .dump_node_statistics ()
170
171
transpose_cnt = current_counter ["Transpose" ]
0 commit comments