Skip to content

Commit 73e4e6d

Browse files
committed
Add support MaxPoolWithArgmax operator support
Modified the following to support MaxPoolWithArgmax: new handler in tf2onnx/onnx_opset/nn.py new test cases in tests/backend_test_base.py add an ignored attribute in tf2onnx/tfonnx.py It partially addresses issue #424
1 parent f2f540b commit 73e4e6d

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

tests/test_backend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,17 @@ def get_conv_getdata(kind=1):
8787
else:
8888
raise ValueError("kind not known")
8989

90+
def get_maxpoolwithargmax_getdata():
91+
data = [
92+
('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
93+
('SAME', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
94+
('SAME', [1, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
95+
('SAME', [1, 10, 5, 1], [1, 4, 4, 1], [1, 1, 1, 1]),
96+
('VALID', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
97+
('VALID', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
98+
]
99+
for idx, v in enumerate(data):
100+
yield (idx,) + v
90101

91102
class BackendTests(Tf2OnnxBackendTestBase):
92103
def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
@@ -2237,6 +2248,24 @@ def test_thresholded_relu(self):
22372248
graph_validator=lambda g: check_op_count(g, "ThresholdedRelu", 1))
22382249
tf.reset_default_graph()
22392250

2251+
@check_tf_min_version("1.13")
2252+
@check_opset_min_version(8, "MaxPoolWithArgmax")
2253+
def test_maxpoolwithargmax(self):
2254+
for tf_shape in ["known", "unknown"]:
2255+
tf.reset_default_graph()
2256+
for p in get_maxpoolwithargmax_getdata():
2257+
_, padding, x_shape, ksize, strides = p
2258+
tf.reset_default_graph()
2259+
x_val = make_xval(x_shape)
2260+
if tf_shape == "known":
2261+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
2262+
else:
2263+
x = tf.placeholder(tf.float32, shape=[None] * x_val.ndim, name=_TFINPUT)
2264+
mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding)
2265+
_ = tf.identity(mp[0], name=_TFOUTPUT)
2266+
_ = tf.identity(mp[1], name=_TFOUTPUT1)
2267+
self.logger.debug(str(p))
2268+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
22402269

22412270
if __name__ == '__main__':
22422271
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,28 @@ def version_4(cls, ctx, node, **kwargs):
319319
add_padding(ctx, node, kernel_shape, strides)
320320
conv_convert_inputs(ctx, node, with_kernel=False)
321321

322+
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
323+
class MaxPoolWithArgmaxOp:
324+
@classmethod
325+
def version_8(cls, ctx, node, **kwargs):
326+
# T output = MaxPool(T input, @list(int) ksize, @list(int) strides, @string padding, @string data_format)
327+
328+
# Set kernel_shape attribute
329+
kernel_shape = node.get_attr("ksize").ints
330+
kernel_shape = [kernel_shape[1], kernel_shape[2]]
331+
node.set_attr("kernel_shape", kernel_shape)
332+
333+
# Set strides attribute
334+
strides = node.get_attr("strides").ints
335+
strides = [strides[1], strides[2]]
336+
node.set_attr("strides", strides)
337+
338+
# The input data_format is NHWC for TF MaxPoolWithArgmax
339+
node.set_attr("data_format", "NHWC")
340+
341+
add_padding(ctx, node, kernel_shape, strides)
342+
conv_convert_inputs(ctx, node, with_kernel=False, input_indices=[0], output_indices=[0, 1])
343+
322344

323345
@tf_op(["BiasAdd", "BiasAddV1"])
324346
class BiasAdd:

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def tflist_to_onnx(node_list, shape_override):
4747
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
4848
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "dynamic_size", "Tmultiples",
4949
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
50-
"Tout", "Tlabels", "Tindex", "element_shape"]
50+
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax"]
5151
# some stats
5252
op_cnt = collections.Counter()
5353
attr_cnt = collections.Counter()

0 commit comments

Comments
 (0)