Skip to content

Commit 257b45f

Browse files
Fixed bug in depthwise conv when kernel is shared between two nodes (#1384)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 1cb41b4 commit 257b45f

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,19 @@ def func(x):
620620
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
621621
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=0.01)
622622

623+
def test_depthwiseconv_shared_kernel(self):
624+
x_shape = [1, 3, 4, 3]
625+
kernel_shape = [3, 3, 3, 3]
626+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
627+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
628+
def func(x, y):
629+
kernel = tf.constant(kernel_val, dtype=tf.float32, name='k')
630+
conv1 = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
631+
conv2 = tf.nn.depthwise_conv2d(y, kernel, strides=[1, 1, 1, 1], padding='VALID')
632+
conv = tf.add(conv1, conv2)
633+
return tf.identity(conv, name=_TFOUTPUT)
634+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=0.08)
635+
623636
@check_tf_min_version("1.14", "tf depthwise_conv2d dilations")
624637
@check_opset_min_version(11, "non-const pads")
625638
def test_depthwiseconv_dilations(self):

tf2onnx/onnx_opset/nn.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,28 +117,22 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
117117
if with_kernel:
118118
# Some ONNX convolution ops require to reshape the kernel (ie. depthwise_conv2d).
119119
if new_kernel_shape:
120-
if node.inputs[1].is_const():
121-
input_node = node.inputs[1]
122-
val = input_node.get_tensor_value(as_list=False)
123-
val = np.reshape(val, new_kernel_shape)
124-
input_node.set_tensor_value(val)
120+
kernel_name = node.input[1]
121+
if ctx.opset < 5:
122+
# Old reshape takes new shape as attribute.
123+
reshape = ctx.insert_new_node_on_input(node, "Reshape", kernel_name)
124+
reshape.set_attr("shape", new_kernel_shape)
125+
reshape.skip_conversion = True
125126
else:
126-
kernel_name = node.input[1]
127-
if ctx.opset < 5:
128-
# Old reshape takes new shape as attribute.
129-
reshape = ctx.insert_new_node_on_input(node, "Reshape", kernel_name)
130-
reshape.set_attr("shape", new_kernel_shape)
131-
reshape.skip_conversion = True
132-
else:
133-
# New reshape takes new shape as input[1].
134-
shape_name = utils.make_name(node.name)
135-
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
127+
# New reshape takes new shape as input[1].
128+
shape_name = utils.make_name(node.name)
129+
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
136130

137-
reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
138-
ctx.replace_input(node, kernel_name, reshape.output[0], 1)
131+
reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
132+
ctx.replace_input(node, kernel_name, reshape.output[0], 1)
139133

140-
reshape.skip_conversion = True
141-
ctx.set_shape(reshape.output[0], new_kernel_shape)
134+
reshape.skip_conversion = True
135+
ctx.set_shape(reshape.output[0], new_kernel_shape)
142136

143137
# Get kernel (may have be changed to a reshape above).
144138
kernel_node = node.inputs[1]

0 commit comments

Comments
 (0)