Skip to content

Commit 1d5d783

Browse files
Fixed bug in padding calculation for padding='SAME' when dilations>1
1 parent 54ad341 commit 1d5d783

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,24 @@ def test_conv2d_6(self):
395395
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
396396
self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)
397397

398+
def test_conv2d_dilation_same(self):
399+
x_shape = [1, 35, 35, 288] # NHWC
400+
kernel_shape = [3, 3, 288, 384] # [filter_height, filter_width, in_channels, out_channels]
401+
strides = [1, 1, 1, 1] # NHWC
402+
dilations = [1, 3, 1, 1] # NHWC
403+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
404+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
405+
self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
406+
407+
def test_conv2d_dilation_strides_same(self):
408+
x_shape = [1, 35, 35, 288] # NHWC
409+
kernel_shape = [3, 3, 288, 384] # [filter_height, filter_width, in_channels, out_channels]
410+
strides = [1, 2, 4, 1] # NHWC
411+
dilations = [1, 3, 1, 1] # NHWC
412+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
413+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
414+
self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
415+
398416
def test_conv3d_1(self):
399417
strides = [1, 1, 1, 1, 1]
400418
dilations = [1, 1, 1, 1, 1]

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
240240
for i in range(spatial):
241241
pad = (
242242
(output_shape[i + 2] - 1) * strides[i]
243-
+ dilations[i] * kernel_shape[i]
243+
+ dilations[i] * (kernel_shape[i] - 1) + 1
244244
- input_shape[i + 2]
245245
)
246246
pad = max(pad, 0)

0 commit comments

Comments
 (0)