Skip to content

Commit c867e52

Browse files
authored
Merge pull request #1046 from onnx/tom/FixDilationPadding
Fixed bug in padding calculation for padding='SAME' when dilations>1
2 parents 4a6125b + 1d5d783 commit c867e52

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,24 @@ def test_conv2d_6(self):
398398
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
399399
self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)
400400

401+
def test_conv2d_dilation_same(self):
402+
x_shape = [1, 35, 35, 288] # NHWC
403+
kernel_shape = [3, 3, 288, 384] # [filter_height, filter_width, in_channels, out_channels]
404+
strides = [1, 1, 1, 1] # NHWC
405+
dilations = [1, 3, 1, 1] # NHWC
406+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
407+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
408+
self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
409+
410+
def test_conv2d_dilation_strides_same(self):
411+
x_shape = [1, 35, 35, 288] # NHWC
412+
kernel_shape = [3, 3, 288, 384] # [filter_height, filter_width, in_channels, out_channels]
413+
strides = [1, 2, 4, 1] # NHWC
414+
dilations = [1, 3, 1, 1] # NHWC
415+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
416+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
417+
self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
418+
401419
def test_conv3d_1(self):
402420
strides = [1, 1, 1, 1, 1]
403421
dilations = [1, 1, 1, 1, 1]

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
240240
for i in range(spatial):
241241
pad = (
242242
(output_shape[i + 2] - 1) * strides[i]
243-
+ dilations[i] * kernel_shape[i]
243+
+ dilations[i] * (kernel_shape[i] - 1) + 1
244244
- input_shape[i + 2]
245245
)
246246
pad = max(pad, 0)

0 commit comments

Comments
 (0)