Skip to content

Commit 02fd7bd

Browse files
committed
fix bug in maxpool
ksize and strides are not always NHWC, actual their formats are aligned with input data's format
1 parent 119f602 commit 02fd7bd

File tree

4 files changed

+45
-14
lines changed

4 files changed

+45
-14
lines changed

tests/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from distutils.version import LooseVersion
1313
from parameterized import parameterized
1414
import numpy as np
15+
import tensorflow as tf
1516
from tf2onnx import constants, logging, utils
1617

1718
__all__ = [
@@ -22,6 +23,7 @@
2223
"check_tf_min_version",
2324
"check_tf_max_version",
2425
"skip_tf_versions",
26+
"skip_if_tf_cpu",
2527
"check_onnxruntime_min_version",
2628
"check_opset_min_version",
2729
"check_opset_max_version",
@@ -181,6 +183,15 @@ def skip_tf_versions(excluded_versions, message=""):
181183
return unittest.skipIf(condition, reason)
182184

183185

186+
def is_tf_gpu():
187+
return tf.test.is_gpu_available()
188+
189+
190+
def skip_if_tf_cpu(message=""):
191+
is_tf_cpu = not is_tf_gpu()
192+
return unittest.skipIf(is_tf_cpu, message)
193+
194+
184195
def check_opset_min_version(min_required_version, message=""):
185196
""" Skip if opset < min_required_version """
186197
config = get_test_config()

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,19 @@ def test_maxpool(self):
258258
self.logger.debug(str(p))
259259
self._run_test_case([_OUTPUT], {_INPUT: x_val})
260260

261+
@skip_if_tf_cpu("only tf_gpu can run maxpool with NCHW format")
262+
def test_maxpool_gpu(self):
263+
# make sure converter behaves well when data format is NCHW
264+
# and when data format is NCHW, only gpu version of tensorflow can run it.
265+
ksize = [1, 1, 2, 2]
266+
strides = [1, 1, 2, 2]
267+
x_val = make_xval([1, 3, 50, 80])
268+
for padding in ["SAME", "VALID"]:
269+
x = tf.placeholder(tf.float32, shape=[None] * 4, name=_TFINPUT)
270+
mp = tf.nn.max_pool(x, ksize, strides, padding=padding, data_format="NCHW")
271+
_ = tf.identity(mp, name=_TFOUTPUT)
272+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
273+
261274
@check_onnxruntime_incompatibility("AveragePool")
262275
def test_avgpool(self):
263276
for tf_shape in ["known", "unknown"]:

tf2onnx/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,13 @@ def data_format(self, val):
136136
self.set_attr("data_format", val)
137137

138138
def is_nhwc(self):
139-
"""Return True if node is in NCHW format."""
139+
"""Return True if node is in NHWC format."""
140140
return self.data_format == "NHWC"
141141

142+
def is_nchw(self):
143+
"""Return True if node is in NCHW format."""
144+
return self.data_format == "NCHW"
145+
142146
def is_const(self):
143147
"""Return True if node is a constant."""
144148
return self.type in ["Const", "ConstV2"]

tf2onnx/onnx_opset/nn.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -301,25 +301,28 @@ def _convert(cls, ctx, node, **kwargs):
301301
# T Y = MaxPool(T X, @AttrType.STRING auto_pad, @AttrType.INTS kernel_shape, @AttrType.INTS pads,
302302
# @AttrType.INTS strides)
303303
# above seems wrong - input[1] is ksize, input[2] is strides
304+
# stride and ksize in tf is not always NHWC, so watch out when converting into onnx's HCHW
304305
if len(node.input) < 3:
305-
kernel_shape = node.get_attr("ksize").ints
306-
kernel_shape = [kernel_shape[1], kernel_shape[2]]
307-
node.set_attr("kernel_shape", kernel_shape)
308-
strides = conv_dims_attr(node, "strides")
306+
kernel_shape_tf = node.get_attr("ksize").ints
307+
strides_tf = node.get_attr("strides").ints
309308
else:
310-
kernel_shape = node.inputs[1].get_tensor_value()
311-
kernel_shape = [kernel_shape[1], kernel_shape[2]]
312-
node.set_attr("kernel_shape", kernel_shape)
313-
314-
strides = node.inputs[2].get_tensor_value()
315-
strides = [strides[1], strides[2]]
316-
node.set_attr("strides", strides)
317-
309+
kernel_shape_tf = node.inputs[1].get_tensor_value()
310+
strides_tf = node.inputs[2].get_tensor_value()
318311
ctx.remove_input(node, node.input[2])
319312
ctx.remove_input(node, node.input[1])
320313

314+
if node.is_nhwc():
315+
kernel_shape_hw = kernel_shape_tf[1:3]
316+
strides_hw = strides_tf[1:3]
317+
elif node.is_nchw():
318+
kernel_shape_hw = kernel_shape_tf[2:4]
319+
strides_hw = strides_tf[2:4]
320+
else:
321+
logger.warning("unexpected data format, please check it")
322+
node.set_attr("kernel_shape", kernel_shape_hw)
323+
node.set_attr("strides", strides_hw)
321324
conv_dims_attr(node, "dilations")
322-
add_padding(ctx, node, kernel_shape, strides)
325+
add_padding(ctx, node, kernel_shape_hw, strides_hw)
323326
conv_convert_inputs(ctx, node, with_kernel=False)
324327

325328

0 commit comments

Comments
 (0)