Skip to content

Commit c3c9c50

Browse files
Added support for converting NonMaxSuppression in opset 10 with dynamic padding
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent e28da28 commit c3c9c50

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

tests/test_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,7 +2822,7 @@ def func(boxes, scores):
28222822

28232823
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
28242824

2825-
@check_opset_min_version(11, "NonMaxSuppressionV4")
2825+
@check_opset_min_version(10, "NonMaxSuppressionV4")
28262826
def test_non_max_suppression_v4(self):
28272827
box_num = 10
28282828
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
@@ -2835,7 +2835,7 @@ def func(boxes, scores):
28352835

28362836
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
28372837

2838-
@check_opset_min_version(11, "NonMaxSuppressionV4")
2838+
@check_opset_min_version(10, "NonMaxSuppressionV4")
28392839
def test_non_max_suppression_v4_no_padding(self):
28402840
box_num = 10
28412841
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
@@ -2849,7 +2849,7 @@ def func(boxes, scores):
28492849
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
28502850

28512851
@check_tf_min_version("1.15")
2852-
@check_opset_min_version(11, "NonMaxSuppressionV5")
2852+
@check_opset_min_version(10, "NonMaxSuppressionV5")
28532853
def test_non_max_suppression_v5(self):
28542854
box_num = 10
28552855
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414

1515
import numpy as np
16-
from onnx import onnx_pb
16+
from onnx import onnx_pb, helper
1717
from onnx.onnx_pb import TensorProto
1818

1919
from tf2onnx import constants, utils
@@ -1485,13 +1485,18 @@ def version_10(cls, ctx, node, **kwargs):
14851485
outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
14861486
elif "pad_to_max_output_size" in node.attr: # V4
14871487
shape_op = ctx.make_node("Shape", inputs=[squeeze_op.output[0]])
1488-
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64))
14891488
sub_op = ctx.make_node("Sub", inputs=[max_output_size, shape_op.output[0]])
1490-
raw_pad = ctx.make_node("Concat", inputs=[const_zero.output[0], sub_op.output[0]], attr={'axis': 0})
1491-
raw_pad_float = ctx.make_node("Cast", inputs=[raw_pad.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
1489+
raw_pad_float = ctx.make_node("Cast", inputs=[sub_op.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
14921490
relu_op = ctx.make_node("Relu", inputs=[raw_pad_float.output[0]])
1493-
pad_val = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64})
1494-
pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]])
1491+
pad_amt = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64})
1492+
if ctx.opset <= 10: # Dynamic padding not supported before opset 11
1493+
zero_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[0])
1494+
padding = ctx.make_node("ConstantOfShape", inputs=[pad_amt.output[0]], attr={"value": zero_tensor})
1495+
pad_op = ctx.make_node("Concat", inputs=[squeeze_op.output[0], padding.output[0]], attr={'axis': 0})
1496+
else:
1497+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64))
1498+
pad_val = ctx.make_node("Concat", inputs=[const_zero.output[0], pad_amt.output[0]], attr={'axis': 0})
1499+
pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]])
14951500
ctx.make_node("Cast", inputs=pad_op.output, name="cast_A", attr={"to": onnx_pb.TensorProto.INT32},
14961501
outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0], op_name_scope=node.name)
14971502
reduce_op = ctx.make_node("ReduceSum", inputs=shape_op.output, attr={"axes": [0], "keepdims": 0})

0 commit comments

Comments (0)