Skip to content

Commit 41bc6d5

Browse files
authored
Merge pull request #741 from RandySheriffH/rashuai/FixDepthwiseConv
Rashuai/fix depthwise conv
2 parents 836e89c + b71b44c commit 41bc6d5

File tree

2 files changed

+25
-30
lines changed

2 files changed

+25
-30
lines changed

tests/test_backend.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,7 @@ def test_depthwiseconv_0(self):
431431
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
432432
conv = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
433433
_ = tf.identity(conv, name=_TFOUTPUT)
434-
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
435-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.08)
434+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-6)
436435

437436
def test_depthwiseconv_1(self):
438437
x_shape = [1, 112, 112, 32]
@@ -443,8 +442,7 @@ def test_depthwiseconv_1(self):
443442
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
444443
conv = tf.nn.depthwise_conv2d(x, kernel, strides=_STRIDE1x1, padding='VALID')
445444
_ = tf.identity(conv, name=_TFOUTPUT)
446-
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
447-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.08)
445+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-6)
448446

449447
def test_dropout(self):
450448
is_training = tf.placeholder_with_default(False, (), "is_training")

tf2onnx/onnx_opset/nn.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,24 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
7272

7373
# kernel must to be transposed
7474
if with_kernel:
75+
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
76+
if new_kernel_shape:
77+
if ctx.opset < 5:
78+
# old reshape takes new shape as attribute
79+
input_name = node.input[1]
80+
reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
81+
reshape.set_attr("shape", new_kernel_shape)
82+
reshape.skip_conversion = True
83+
else:
84+
# new reshape takes new shape as input[1]
85+
shape_name = utils.make_name(node.name)
86+
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
87+
input_name = node.input[1]
88+
reshape = ctx.make_node("Reshape", [input_name, shape_name])
89+
ctx.replace_input(node, input_name, reshape.output[0])
90+
reshape.skip_conversion = True
91+
ctx.set_shape(reshape.output[0], new_kernel_shape)
92+
7593
parent = node.inputs[1]
7694
need_transpose = True
7795
if node.inputs[1].is_const():
@@ -91,24 +109,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
91109
new_shape = spatial_map(ctx.get_shape(input_name), constants.HWCN_TO_NCHW)
92110
ctx.set_shape(transpose.output[0], new_shape)
93111

94-
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
95-
if new_kernel_shape:
96-
if ctx.opset < 5:
97-
# old reshape takes new shape as attribute
98-
input_name = node.input[1]
99-
reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
100-
reshape.set_attr("shape", new_kernel_shape)
101-
reshape.skip_conversion = True
102-
else:
103-
# new reshape takes new shape as input[1]
104-
shape_name = utils.make_name(node.name)
105-
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
106-
input_name = node.input[1]
107-
reshape = ctx.make_node("Reshape", [input_name, shape_name])
108-
ctx.replace_input(node, input_name, reshape.output[0])
109-
reshape.skip_conversion = True
110-
ctx.set_shape(reshape.output[0], new_kernel_shape)
111-
112112
# transpose outputs if needed
113113
if node.is_nhwc():
114114
for idx in output_indices:
@@ -280,24 +280,21 @@ def version_1(cls, ctx, node, **kwargs):
280280
if len(input_shape) != 4:
281281
raise ValueError("only Conv2D is supported")
282282

283-
if node.is_nhwc():
284-
i_n, i_h, i_w, i_c = input_shape
285-
else:
286-
i_n, i_c, i_h, i_w = input_shape
287-
288283
kernel_shape = ctx.get_shape(node.input[1])
289284
if len(kernel_shape) != 4:
290285
raise ValueError("only Conv2D is supported")
291286
k_h, k_w, k_input_channels, k_channel_multiplier = kernel_shape
292-
k_output_channels = i_c * k_channel_multiplier
287+
if k_input_channels < 1:
288+
raise ValueError("input channel must be positive")
289+
k_output_channels = k_input_channels * k_channel_multiplier
293290

294291
node.set_attr("kernel_shape", [k_h, k_w])
295292
strides = conv_dims_attr(node, "strides")
296293
conv_dims_attr(node, "dilations")
297-
node.set_attr("group", i_c)
294+
node.set_attr("group", k_input_channels)
298295
add_padding(ctx, node, kernel_shape, strides)
299296

300-
new_kernel_shape = [k_output_channels, 1, k_h, k_w]
297+
new_kernel_shape = [k_h, k_w, 1, k_output_channels]
301298
conv_convert_inputs(ctx, node, with_kernel=True, new_kernel_shape=new_kernel_shape)
302299

303300

0 commit comments

Comments
 (0)