@@ -72,6 +72,24 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
72
72
73
73
# kernel must to be transposed
74
74
if with_kernel :
75
+ # some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
76
+ if new_kernel_shape :
77
+ if ctx .opset < 5 :
78
+ # old reshape takes new shape as attribute
79
+ input_name = node .input [1 ]
80
+ reshape = ctx .insert_new_node_on_input (node , "Reshape" , input_name )
81
+ reshape .set_attr ("shape" , new_kernel_shape )
82
+ reshape .skip_conversion = True
83
+ else :
84
+ # new reshape takes new shape as input[1]
85
+ shape_name = utils .make_name (node .name )
86
+ ctx .make_const (shape_name , np .array (new_kernel_shape , dtype = np .int64 ))
87
+ input_name = node .input [1 ]
88
+ reshape = ctx .make_node ("Reshape" , [input_name , shape_name ])
89
+ ctx .replace_input (node , input_name , reshape .output [0 ])
90
+ reshape .skip_conversion = True
91
+ ctx .set_shape (reshape .output [0 ], new_kernel_shape )
92
+
75
93
parent = node .inputs [1 ]
76
94
need_transpose = True
77
95
if node .inputs [1 ].is_const ():
@@ -91,24 +109,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
91
109
new_shape = spatial_map (ctx .get_shape (input_name ), constants .HWCN_TO_NCHW )
92
110
ctx .set_shape (transpose .output [0 ], new_shape )
93
111
94
- # some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
95
- if new_kernel_shape :
96
- if ctx .opset < 5 :
97
- # old reshape takes new shape as attribute
98
- input_name = node .input [1 ]
99
- reshape = ctx .insert_new_node_on_input (node , "Reshape" , input_name )
100
- reshape .set_attr ("shape" , new_kernel_shape )
101
- reshape .skip_conversion = True
102
- else :
103
- # new reshape takes new shape as input[1]
104
- shape_name = utils .make_name (node .name )
105
- ctx .make_const (shape_name , np .array (new_kernel_shape , dtype = np .int64 ))
106
- input_name = node .input [1 ]
107
- reshape = ctx .make_node ("Reshape" , [input_name , shape_name ])
108
- ctx .replace_input (node , input_name , reshape .output [0 ])
109
- reshape .skip_conversion = True
110
- ctx .set_shape (reshape .output [0 ], new_kernel_shape )
111
-
112
112
# transpose outputs if needed
113
113
if node .is_nhwc ():
114
114
for idx in output_indices :
@@ -280,24 +280,21 @@ def version_1(cls, ctx, node, **kwargs):
280
280
if len (input_shape ) != 4 :
281
281
raise ValueError ("only Conv2D is supported" )
282
282
283
- if node .is_nhwc ():
284
- i_n , i_h , i_w , i_c = input_shape
285
- else :
286
- i_n , i_c , i_h , i_w = input_shape
287
-
288
283
kernel_shape = ctx .get_shape (node .input [1 ])
289
284
if len (kernel_shape ) != 4 :
290
285
raise ValueError ("only Conv2D is supported" )
291
286
k_h , k_w , k_input_channels , k_channel_multiplier = kernel_shape
292
- k_output_channels = i_c * k_channel_multiplier
287
+ if k_input_channels < 1 :
288
+ raise ValueError ("input channel must be positive" )
289
+ k_output_channels = k_input_channels * k_channel_multiplier
293
290
294
291
node .set_attr ("kernel_shape" , [k_h , k_w ])
295
292
strides = conv_dims_attr (node , "strides" )
296
293
conv_dims_attr (node , "dilations" )
297
- node .set_attr ("group" , i_c )
294
+ node .set_attr ("group" , k_input_channels )
298
295
add_padding (ctx , node , kernel_shape , strides )
299
296
300
- new_kernel_shape = [k_output_channels , 1 , k_h , k_w ]
297
+ new_kernel_shape = [k_h , k_w , 1 , k_output_channels ]
301
298
conv_convert_inputs (ctx , node , with_kernel = True , new_kernel_shape = new_kernel_shape )
302
299
303
300
0 commit comments