Skip to content

Commit 25973f3

Browse files
authored
Merge pull request #498 from chinhuang007/add-maxpoolwithargmax
Add MaxPoolWithArgmax operator support
2 parents 03a6379 + 73e4e6d commit 25973f3

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

tests/test_backend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,17 @@ def get_conv_getdata(kind=1):
8787
else:
8888
raise ValueError("kind not known")
8989

90+
def get_maxpoolwithargmax_getdata():
91+
data = [
92+
('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
93+
('SAME', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
94+
('SAME', [1, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
95+
('SAME', [1, 10, 5, 1], [1, 4, 4, 1], [1, 1, 1, 1]),
96+
('VALID', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
97+
('VALID', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
98+
]
99+
for idx, v in enumerate(data):
100+
yield (idx,) + v
90101

91102
class BackendTests(Tf2OnnxBackendTestBase):
92103
def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
@@ -2236,6 +2247,24 @@ def test_thresholded_relu(self):
22362247
graph_validator=lambda g: check_op_count(g, "ThresholdedRelu", 1))
22372248
tf.reset_default_graph()
22382249

2250+
@check_tf_min_version("1.13")
2251+
@check_opset_min_version(8, "MaxPoolWithArgmax")
2252+
def test_maxpoolwithargmax(self):
2253+
for tf_shape in ["known", "unknown"]:
2254+
tf.reset_default_graph()
2255+
for p in get_maxpoolwithargmax_getdata():
2256+
_, padding, x_shape, ksize, strides = p
2257+
tf.reset_default_graph()
2258+
x_val = make_xval(x_shape)
2259+
if tf_shape == "known":
2260+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
2261+
else:
2262+
x = tf.placeholder(tf.float32, shape=[None] * x_val.ndim, name=_TFINPUT)
2263+
mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding)
2264+
_ = tf.identity(mp[0], name=_TFOUTPUT)
2265+
_ = tf.identity(mp[1], name=_TFOUTPUT1)
2266+
self.logger.debug(str(p))
2267+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
22392268

22402269
if __name__ == '__main__':
22412270
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,28 @@ def _convert(cls, ctx, node, **kwargs):
327327
add_padding(ctx, node, kernel_shape, strides)
328328
conv_convert_inputs(ctx, node, with_kernel=False)
329329

330+
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
331+
class MaxPoolWithArgmaxOp:
332+
@classmethod
333+
def version_8(cls, ctx, node, **kwargs):
334+
# T output = MaxPool(T input, @list(int) ksize, @list(int) strides, @string padding, @string data_format)
335+
336+
# Set kernel_shape attribute
337+
kernel_shape = node.get_attr("ksize").ints
338+
kernel_shape = [kernel_shape[1], kernel_shape[2]]
339+
node.set_attr("kernel_shape", kernel_shape)
340+
341+
# Set strides attribute
342+
strides = node.get_attr("strides").ints
343+
strides = [strides[1], strides[2]]
344+
node.set_attr("strides", strides)
345+
346+
# The input data_format is NHWC for TF MaxPoolWithArgmax
347+
node.set_attr("data_format", "NHWC")
348+
349+
add_padding(ctx, node, kernel_shape, strides)
350+
conv_convert_inputs(ctx, node, with_kernel=False, input_indices=[0], output_indices=[0, 1])
351+
330352

331353
@tf_op(["BiasAdd", "BiasAddV1"])
332354
class BiasAdd:

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def tflist_to_onnx(node_list, shape_override):
4747
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
4848
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "dynamic_size", "Tmultiples",
4949
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
50-
"Tout", "Tlabels", "Tindex", "element_shape"]
50+
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax"]
5151
# some stats
5252
op_cnt = collections.Counter()
5353
attr_cnt = collections.Counter()

0 commit comments

Comments
 (0)