Skip to content

Commit 0a9db3f

Browse files
authored
Merge pull request #850 from jignparm/jignparm/fix_conv2d_with_pad
Fix conv2d with pad rewriter
2 parents 1f2b37d + 8e3368f commit 0a9db3f

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

tests/test_backend.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def func(x):
403403
return tf.identity(conv, name=_TFOUTPUT)
404404
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-5)
405405

406-
def test_conv2d_with_pad(self):
406+
def test_conv2d_with_pad_valid(self):
407407
x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC)
408408
w = np.random.random_sample([3, 3, 1, 2]).astype(np.float32)
409409
strides = [1, 1, 1, 1]
@@ -414,6 +414,17 @@ def func(x):
414414
return tf.identity(conv, name=_TFOUTPUT)
415415
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-5)
416416

417+
def test_conv2d_with_pad_same(self):
418+
x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC)
419+
w = np.random.random_sample([3, 3, 1, 2]).astype(np.float32)
420+
strides = [1, 1, 1, 1]
421+
def func(x):
422+
kernel = tf.constant(w, dtype=tf.float32, name='k')
423+
x_pad = tf.pad(x, paddings=[[0, 0], [2, 2], [2, 2], [0, 0]])
424+
conv = tf.nn.conv2d(x_pad, kernel, strides=strides, padding="SAME")
425+
return tf.identity(conv, name=_TFOUTPUT)
426+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-5)
427+
417428
def test_conv2d_transpose(self):
418429
x_shape = [2, 6, 4, 3]
419430
output_shape = [2, 13, 9, 2]

tf2onnx/rewriter/conv2d_with_pad_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def rewrite_conv2d_with_pad(g, ops):
3737
if mode not in [None, "constant"] or len(pad.input) >= 3:
3838
continue
3939
# Conv2D already has a pad
40-
if conv.get_attr("padding") == "SAME":
40+
if conv.get_attr("padding").s.decode("utf-8") == "SAME":
4141
continue
4242

4343
logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)

0 commit comments

Comments
 (0)