@@ -74,40 +74,40 @@ def pre_optimize_action(self):
74
74
self ._g .topological_sort (self ._g .get_nodes ())
75
75
76
76
def post_optimize_action(self):
    """Replace data-preserving 4-D layout Transposes with Reshape nodes.

    A Transpose between NHWC and NCHW does not reorder any element when the
    channel dimension is 1 or when both spatial dimensions are 1. Such a
    Transpose is swapped for a Reshape because, per the original author's
    note, a Transpose copies data even when it effectively does nothing.
    The graph is topologically re-sorted once all replacements are done.
    """
    def _calculate_new_shape(graph, op):
        """Return the name of a tensor holding the transposed shape of op's input.

        When the known input shape has at most one unknown (-1) dimension the
        target shape is emitted as an int64 constant; otherwise it is computed
        in-graph as Gather(Shape(input), perm).
        """
        input_shape = graph.get_shape(op.input[0])
        # The caller only passes NHWC->NCHW or NCHW->NHWC transposes, so
        # anything that is not the former is treated as the latter.
        perm = [0, 3, 1, 2] if is_nchw_transpose(op) else [0, 2, 3, 1]

        if input_shape.count(-1) <= 1:
            new_shape = [input_shape[axis] for axis in perm]
            return graph.make_const(utils.make_name("new_shape"),
                                    np.array(new_shape, dtype=np.int64)).output[0]

        # Reshape requires that the output shape contains at most one -1;
        # with more unknown dims the shape has to be derived at runtime.
        runtime_shape = graph.make_node("Shape", [op.input[0]]).output[0]
        # Pin the index dtype: numpy's default integer width is platform
        # dependent, and the static branch above already uses int64.
        indice = graph.make_const(utils.make_name("indice"),
                                  np.array(perm, dtype=np.int64)).output[0]
        return graph.make_node("Gather", [runtime_shape, indice]).output[0]

    # if channel==1 or height==width==1, replace transpose with reshape
    # Iterate over a snapshot: the loop body removes and adds graph nodes, and
    # mutating a list while iterating it would skip the following element.
    for op in list(self.nodes):
        if op.type != "Transpose":
            continue

        input_shape = self._g.get_shape(op.input[0])
        if not input_shape:
            continue

        nchw_noop = is_nchw_transpose(op) \
            and (input_shape[3] == 1 or input_shape[1:3] == [1, 1])
        nhwc_noop = not nchw_noop and is_nhwc_transpose(op) \
            and (input_shape[1] == 1 or input_shape[2:4] == [1, 1])
        if nchw_noop or nhwc_noop:
            new_shape = _calculate_new_shape(self._g, op)
            # replace transpose with reshape, reusing the node name and outputs
            # so downstream consumers stay wired up
            self._g.remove_node(op.name)
            self._g.make_node("Reshape", [op.input[0], new_shape], name=op.name, outputs=op.output)

    self._g.topological_sort(self._g.get_nodes())
111
111
112
112
def merge_duplicated_transposes (self ):
113
113
# strategy used in previous procedure is to move transpose nodes down if possible,
0 commit comments