Skip to content

Commit 6ec695b

Browse files
Merge pull request #1071 from NikolasMarkou/DepthwiseConv2d_translation
#1070 replacing Reshape and Transpose operators to the kernel that are invalid in tensorrt with just a reshaped tensor
2 parents 080c7a4 + f4e511b commit 6ec695b

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

tests/test_backend.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
_STRIDE1x1 = [1, 1, 1, 1]
3636
_KERNEL3x3 = [3, 3, 1, 1]
37+
_DILATIONS1x1 = [1, 1, 1, 1]
3738

3839
# names for input and outputs for tests
3940
_TFINPUT = "input"
@@ -348,7 +349,7 @@ def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None, rt
348349
if strides is None:
349350
strides = _STRIDE1x1
350351
if dilations is None:
351-
dilations = [1, 1, 1, 1]
352+
dilations = _DILATIONS1x1
352353
def func(x):
353354
kernel = tf.constant(w, dtype=tf.float32, name='k')
354355
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, dilations=dilations)
@@ -3580,6 +3581,27 @@ def func(y, x):
35803581
self._run_test_case(
35813582
func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
35823583

3584+
def _conv_kernel_as_input_test(self, x_val, w_val, strides=None,
3585+
padding="VALID", dilations=None, rtol=1e-07):
3586+
if strides is None:
3587+
strides = _STRIDE1x1
3588+
if dilations is None:
3589+
dilations = _DILATIONS1x1
3590+
3591+
def func(x, kernel):
3592+
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding,
3593+
dilations=dilations)
3594+
return tf.identity(conv, name=_TFOUTPUT)
3595+
3596+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: w_val}, rtol=rtol)
3597+
3598+
def test_conv2d_1_kernel_as_input(self):
3599+
x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC)
3600+
w_val = np.array([[2., 1., 1.],
3601+
[1., 3., 1.],
3602+
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
3603+
self._conv_kernel_as_input_test(x_val, w_val)
3604+
35833605

35843606
if __name__ == '__main__':
35853607
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,24 +117,28 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
117117
if with_kernel:
118118
# Some ONNX convolution ops require to reshape the kernel (ie. depthwise_conv2d).
119119
if new_kernel_shape:
120-
kernel_name = node.input[1]
121-
122-
if ctx.opset < 5:
123-
# Old reshape takes new shape as attribute.
124-
reshape = ctx.insert_new_node_on_input(node, "Reshape", kernel_name)
125-
reshape.set_attr("shape", new_kernel_shape)
126-
reshape.skip_conversion = True
120+
if node.inputs[1].is_const():
121+
input_node = node.inputs[1]
122+
val = input_node.get_tensor_value(as_list=False)
123+
val = np.reshape(val, new_kernel_shape)
124+
input_node.set_tensor_value(val)
127125
else:
128-
# New reshape takes new shape as input[1].
129-
shape_name = utils.make_name(node.name)
130-
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
131-
132-
reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
133-
ctx.replace_input(node, kernel_name, reshape.output[0], 1)
134-
135-
reshape.skip_conversion = True
136-
137-
ctx.set_shape(reshape.output[0], new_kernel_shape)
126+
kernel_name = node.input[1]
127+
if ctx.opset < 5:
128+
# Old reshape takes new shape as attribute.
129+
reshape = ctx.insert_new_node_on_input(node, "Reshape", kernel_name)
130+
reshape.set_attr("shape", new_kernel_shape)
131+
reshape.skip_conversion = True
132+
else:
133+
# New reshape takes new shape as input[1].
134+
shape_name = utils.make_name(node.name)
135+
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
136+
137+
reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
138+
ctx.replace_input(node, kernel_name, reshape.output[0], 1)
139+
140+
reshape.skip_conversion = True
141+
ctx.set_shape(reshape.output[0], new_kernel_shape)
138142

139143
# Get kernel (may have be changed to a reshape above).
140144
kernel_node = node.inputs[1]

0 commit comments

Comments
 (0)