@@ -55,11 +55,10 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         # transpose input if needed, no need to record shapes on input
         for idx in input_indices:
             parent = node.inputs[idx]
-            if node.inputs[idx].is_const():
-                # if input is a constant, transpose that one
-                if not parent.data_format:
-                    val = parent.get_tensor_value(as_list=False)
-                    parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
+            if node.inputs[idx].is_const() and len(ctx.find_output_consumers(node.input[1])) == 1:
+                # if input is a constant, transpose that one if we are the only consumer
+                val = parent.get_tensor_value(as_list=False)
+                parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
             else:
                 # if input comes from an op, insert transpose op
                 input_name = node.input[idx]
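For readers skimming the hunk: the constant-folding branch boils down to an in-place numpy transpose of the stored tensor value. A minimal standalone sketch, assuming constants.NHWC_TO_NCHW is the permutation [0, 3, 1, 2] (the variable names here are illustrative, not tf2onnx API):

import numpy as np

# Permutation used above; assumed value of tf2onnx's constants.NHWC_TO_NCHW.
NHWC_TO_NCHW = [0, 3, 1, 2]

# A dummy NHWC tensor: batch=1, height=4, width=4, channels=3.
val = np.arange(1 * 4 * 4 * 3, dtype=np.float32).reshape(1, 4, 4, 3)

nchw = val.transpose(NHWC_TO_NCHW)
print(nchw.shape)  # (1, 3, 4, 4) -- channels moved ahead of the spatial dims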
@@ -70,33 +69,27 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                 if shape is not None:
                     new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
                     ctx.set_shape(transpose.output[0], new_shape)
-                    parent.data_format = "NCHW"

     # kernel must be transposed
     if with_kernel:
         parent = node.inputs[1]
         need_transpose = True
         if node.inputs[1].is_const():
             # kernel is const - transpose the const if we are the only consumer of const
-            # TODO: maybe we should make a copy of the const, or look at the other consumers
-            # if they'd want a transpose as well.
             consumers = ctx.find_output_consumers(node.input[1])
             if len(consumers) == 1:
                 val = parent.get_tensor_value(as_list=False)
                 val = val.transpose(constants.HWCN_TO_NCHW)
                 parent.set_tensor_value(val)
-                parent.data_format = "NCHW"
                 need_transpose = False

         if need_transpose:
             input_name = node.input[1]
             transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
             transpose.set_attr("perm", constants.HWCN_TO_NCHW)
             transpose.skip_conversion = True
-            ctx.copy_shape(input_name, transpose.output[0])
             new_shape = spatial_map(ctx.get_shape(input_name), constants.HWCN_TO_NCHW)
             ctx.set_shape(transpose.output[0], new_shape)
-            parent.data_format = "NCHW"

         # some onnx conv ops require reshaping the kernel (i.e. depthwise_conv2d)
         if new_kernel_shape:
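The kernel permutation in this hunk converts between filter layouts: TensorFlow stores conv kernels as [filter_height, filter_width, in_channels, out_channels] (HWCN), while ONNX Conv expects [out_channels, in_channels, kH, kW]. A minimal sketch, assuming constants.HWCN_TO_NCHW is [3, 2, 0, 1]:

import numpy as np

# Permutation used above; assumed value of tf2onnx's constants.HWCN_TO_NCHW.
HWCN_TO_NCHW = [3, 2, 0, 1]

# A dummy TF conv kernel: 5x5 filter, 3 input channels, 16 output channels.
kernel = np.zeros((5, 5, 3, 16), dtype=np.float32)

onnx_weight = kernel.transpose(HWCN_TO_NCHW)
print(onnx_weight.shape)  # (16, 3, 5, 5) -- ONNX Conv wants [M, C, kH, kW]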
@@ -129,7 +122,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             ctx.set_shape(transpose.output[0], output_shape)
             # Transpose TF NHWC shape back to NCHW shape for current ONNX conv node output
             ctx.set_shape(output_name, spatial_map(output_shape, constants.NHWC_TO_NCHW))
-            node.data_format = "NCHW"
+        node.data_format = "NCHW"


 def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
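Taken together, the sole-consumer checks added in this commit guard against an in-place mutation hazard: a constant shared by several nodes cannot be transposed in place, because consumers that are not being converted would silently read the permuted values. A minimal sketch of the decision, using a plain dict in place of the graph's constant store (graph_consts and consumers_of_w are hypothetical names, not tf2onnx API):

import numpy as np

NHWC_TO_NCHW = [0, 3, 1, 2]

# Two nodes share one constant tensor, stored by name (hypothetical mini-graph).
graph_consts = {"w:0": np.zeros((1, 4, 4, 3), np.float32)}
consumers_of_w = ["Conv2D_being_converted", "SomeOtherOp"]

if len(consumers_of_w) == 1:
    # Sole consumer: safe to rewrite the stored value in place,
    # analogous to parent.set_tensor_value(...) above.
    graph_consts["w:0"] = graph_consts["w:0"].transpose(NHWC_TO_NCHW)
else:
    # Shared const: rewriting it would corrupt SomeOtherOp's NHWC input,
    # so the converter inserts a Transpose node on the converted edge
    # and leaves the shared const untouched.
    pass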