Skip to content

Commit 1a5fc52

Browse files
fix NCHW converting bug when in pad_rewriter
1 parent d3d301a commit 1a5fc52

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tf2onnx/rewriter/conv2d_with_pad_rewriter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,16 @@ def rewrite_conv2d_with_pad(g, ops):
4343
logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
4444
paddings_val = np.array(paddings.get_tensor_value())
4545
# can't pad on batch or channel dimensions
46-
if np.any(paddings_val[0]) or np.any(paddings_val[3]):
47-
continue
46+
data_format = conv.get_attr("data_format").s.decode("utf-8")
47+
if data_format == "NHWC":
48+
if np.any(paddings_val[0]) or np.any(paddings_val[3]):
49+
continue
50+
paddings_val = paddings_val[1:3]
51+
else:
52+
if np.any(paddings_val[0]) or np.any(paddings_val[1]):
53+
continue
54+
paddings_val = paddings_val[2:4]
4855

49-
paddings_val = paddings_val[1:3]
5056
paddings_val = paddings_val.transpose().flatten()
5157
g.replace_input(conv, conv.input[0], pad.input[0])
5258
# convert Conv2D

0 commit comments

Comments
 (0)