
Commit 23e8ae4

Merge pull request #860 from RandySheriffH/rashuai/NMSV4V5
Convert NMS V4&V5
2 parents: 557939b + 7bc260f

File tree

2 files changed: 73 additions, 7 deletions


tests/test_backend.py

Lines changed: 42 additions & 0 deletions
@@ -2580,10 +2580,52 @@ def test_non_max_suppression(self):
         box_num = 10
         boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
         scores_val = np.random.random_sample([box_num]).astype(np.float32)
+
         def func(boxes, scores):
             res1 = tf.image.non_max_suppression(boxes, scores, max_output_size=int(box_num / 2))
             res2 = tf.image.non_max_suppression(boxes, scores, max_output_size=0)
             return tf.identity(res1, name=_TFOUTPUT), tf.identity(res2, name=_TFOUTPUT1)
+
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
+
+    @check_opset_min_version(11, "NonMaxSuppressionV4")
+    def test_non_max_suppression_v4(self):
+        box_num = 10
+        boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
+        scores_val = np.random.random_sample([box_num]).astype(np.float32)
+
+        def func(boxes, scores):
+            ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num * 2),
+                                                             pad_to_max_output_size=True)
+            return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
+
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
+
+    @check_opset_min_version(11, "NonMaxSuppressionV4")
+    def test_non_max_suppression_v4_no_padding(self):
+        box_num = 10
+        boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
+        scores_val = np.random.random_sample([box_num]).astype(np.float32)
+
+        def func(boxes, scores):
+            ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num),
+                                                             pad_to_max_output_size=False)
+            return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
+
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
+
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(11, "NonMaxSuppressionV5")
+    def test_non_max_suppression_v5(self):
+        box_num = 10
+        boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
+        scores_val = np.random.random_sample([box_num]).astype(np.float32)
+
+        def func(boxes, scores):
+            ret1, ret2 = tf.image.non_max_suppression_with_scores(boxes, scores, max_output_size=int(box_num / 2),
+                                                                  soft_nms_sigma=0.0)
+            return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
+
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})

     def _conv1d_test(self, x_val, w, stride=None, padding="VALID", rtol=1e-07):
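
For context on the two TF ops the new tests exercise, here is a rough sketch (not part of this PR) of their output contracts, assuming TF 1.15 semantics; the tensor values and names are illustrative only:

    import numpy as np
    import tensorflow as tf

    boxes = np.random.random_sample([10, 4]).astype(np.float32)
    scores = np.random.random_sample([10]).astype(np.float32)

    # NonMaxSuppressionV4: selected indices, optionally zero-padded to
    # max_output_size, plus the count of valid entries.
    selected_padded, num_valid = tf.image.non_max_suppression_padded(
        boxes, scores, max_output_size=20, pad_to_max_output_size=True)

    # NonMaxSuppressionV5: selected indices plus their scores; the converter
    # below only accepts soft_nms_sigma == 0, so the scores are simply gathered.
    selected, selected_scores = tf.image.non_max_suppression_with_scores(
        boxes, scores, max_output_size=5, soft_nms_sigma=0.0)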

tf2onnx/onnx_opset/tensor.py

Lines changed: 31 additions & 7 deletions
@@ -1452,7 +1452,8 @@ def version_10(cls, ctx, node, **kwargs):
             raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
 
 
-@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
+@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3", "NonMaxSuppressionV4", "NonMaxSuppressionV5"],
+       onnx_op="NonMaxSuppression")
 class NonMaxSuppression:
     @classmethod
     def version_10(cls, ctx, node, **kwargs):
@@ -1464,18 +1465,41 @@ def version_10(cls, ctx, node, **kwargs):
         # onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
         # while tf's output is [num_selected_boxes]
         ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
-        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
+        input_score = ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
         ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
         # replace original node with nonmaxsurppress + slice + squeeze + cast
-        dtypes = [ctx.get_dtype(node.output[0])]
-        shapes = [ctx.get_shape(node.output[0])]
+        dtypes = [[ctx.get_dtype(output)] for output in node.output]
+        shapes = [[ctx.get_shape(output)] for output in node.output]
+        max_output_size = node.input[2]
+        utils.make_sure(len(node.inputs) <= 5 or int(node.inputs[5].get_tensor_value(False)) == 0,
+                        "soft_nms_sigma must be 0")
         ctx.remove_node(node.name)
-        new_nonmaxsurppress = ctx.make_node(node.type, node.input).output[0]
+        new_nonmaxsurppress = ctx.make_node(node.type, node.input[: 5]).output[0]
         slice_op = GraphBuilder(ctx).make_slice({"data": new_nonmaxsurppress,
                                                  "axes": [1], "ends": [3], "starts": [2]})
         squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
-        ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
-                      name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)
+        if len(node.input) > 5:  # v5, called by ..._with_scores(), pad_to_max_output_size always False
+            ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
+                          outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0])
+            ctx.make_node("Gather", inputs=[input_score.input[0], squeeze_op.output[0]],
+                          outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
+        elif "pad_to_max_output_size" in node.attr:  # V4
+            shape_op = ctx.make_node("Shape", inputs=[squeeze_op.output[0]])
+            const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64))
+            sub_op = ctx.make_node("Sub", inputs=[max_output_size, shape_op.output[0]])
+            raw_pad = ctx.make_node("Concat", inputs=[const_zero.output[0], sub_op.output[0]], attr={'axis': 0})
+            raw_pad_float = ctx.make_node("Cast", inputs=[raw_pad.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
+            relu_op = ctx.make_node("Relu", inputs=[raw_pad_float.output[0]])
+            pad_val = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64})
+            pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]])
+            ctx.make_node("Cast", inputs=pad_op.output, name="cast_A", attr={"to": onnx_pb.TensorProto.INT32},
+                          outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0])
+            reduce_op = ctx.make_node("ReduceSum", inputs=shape_op.output, attr={"axes": [0], "keepdims": 0})
+            ctx.make_node("Cast", inputs=[reduce_op.output[0]], name="cast_B", attr={"to": onnx_pb.TensorProto.INT32},
+                          outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
+        else:
+            ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
+                          name=node.name, outputs=node.output, dtypes=dtypes[0], shapes=shapes[0])
 
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
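
The V4 branch above emulates pad_to_max_output_size with a small ONNX subgraph (Shape → Sub → Concat → Relu → Pad, with Casts for dtype bookkeeping, plus a ReduceSum for the valid-count output), while the V5 branch emits a Gather on the original scores using the selected indices. Below is a minimal NumPy sketch of the V4 padding computation, purely illustrative and not taken from the PR:

    import numpy as np

    def pad_selected_indices(selected_idx, max_output_size):
        # Shape: number of indices NonMaxSuppression actually kept.
        num_valid = selected_idx.shape[0]
        # Sub + Relu: pad amount, clamped at zero in case more boxes were kept than requested.
        pad_amount = max(max_output_size - num_valid, 0)
        # Concat + Pad: zero-pad the index vector on the right up to max_output_size.
        padded = np.pad(selected_idx, (0, pad_amount))
        # Cast both outputs to int32, matching the TF op's output dtypes.
        return padded.astype(np.int32), np.int32(num_valid)

    # Example: pad_selected_indices(np.array([3, 7, 1]), 5) -> ([3, 7, 1, 0, 0], 3)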
