@@ -24,7 +24,7 @@ def is_nhwc_transpose(transpose_node):
24
24
25
25
def is_nchw_transpose (transpose_node ):
26
26
perm_attr = transpose_node .get_attr ('perm' )
27
- return transpose_node .type == "Transpose" and perm_attr and perm_attr .ints == [ 0 , 3 , 1 , 2 ]
27
+ return transpose_node .type == "Transpose" and perm_attr and perm_attr .ints == NHWC_TO_NCHW
28
28
29
29
30
30
def is_useless_transpose (transpose_node ):
@@ -87,7 +87,7 @@ def _calculate_new_shape(graph, op):
87
87
# reshape requires tha output shape can only contain one -1, if not some extra op needed.
88
88
input_shape = graph .make_node ("Shape" , [op .input [0 ]]).output [0 ]
89
89
if is_nchw_transpose (op ):
90
- indice = graph .make_const (utils .make_name ("indice" ), np .array ([ 0 , 3 , 1 , 2 ] )).output [0 ]
90
+ indice = graph .make_const (utils .make_name ("indice" ), np .array (NHWC_TO_NCHW )).output [0 ]
91
91
else :
92
92
indice = graph .make_const (utils .make_name ("indice" ), np .array (NCHW_TO_NHWC )).output [0 ]
93
93
@@ -246,7 +246,7 @@ def _switch_transpose_and_node(self, node, trans):
246
246
shape = self ._g .get_shape (node .output [0 ])
247
247
if shape :
248
248
# only nhwc transpose can reach here
249
- new_shape = [shape [i ] for i in [ 0 , 3 , 1 , 2 ] ]
249
+ new_shape = [shape [i ] for i in NHWC_TO_NCHW ]
250
250
self ._g .set_shape (node .output [0 ], new_shape )
251
251
return True
252
252
@@ -302,7 +302,7 @@ def _create_transpose_pairs_after_node(self, node):
302
302
non_nchw_trans_consumers = self ._get_non_nchw_transpose_output_nodes (node )
303
303
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
304
304
for consumer in non_nchw_trans_consumers :
305
- nchw_node = self ._g .make_node ("Transpose" , [node .output [0 ]], attr = {"perm" : [ 0 , 3 , 1 , 2 ] })
305
+ nchw_node = self ._g .make_node ("Transpose" , [node .output [0 ]], attr = {"perm" : NHWC_TO_NCHW })
306
306
nhwc_node = self ._g .make_node ("Transpose" , [nchw_node .output [0 ]], attr = {"perm" : NCHW_TO_NHWC })
307
307
self ._g .replace_input (consumer , node .output [0 ], nhwc_node .output [0 ])
308
308
0 commit comments