Skip to content

Commit f9b864f

Browse files
committed
add v4 v5 support
1 parent 557939b commit f9b864f

File tree

2 files changed

+72
-8
lines changed

2 files changed

+72
-8
lines changed

tests/test_backend.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2580,13 +2580,54 @@ def test_non_max_suppression(self):
25802580
box_num = 10
25812581
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
25822582
scores_val = np.random.random_sample([box_num]).astype(np.float32)
2583+
25832584
def func(boxes, scores):
25842585
res1 = tf.image.non_max_suppression(boxes, scores, max_output_size=int(box_num / 2))
25852586
res2 = tf.image.non_max_suppression(boxes, scores, max_output_size=0)
25862587
return tf.identity(res1, name=_TFOUTPUT), tf.identity(res2, name=_TFOUTPUT1)
2588+
2589+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
2590+
2591+
@check_opset_min_version(11, "NonMaxSuppressionV4")
2592+
def test_non_max_suppression_v4(self):
2593+
box_num = 10
2594+
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
2595+
scores_val = np.random.random_sample([box_num]).astype(np.float32)
2596+
2597+
def func(boxes, scores):
2598+
ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num * 2),
2599+
pad_to_max_output_size=True)
2600+
return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
2601+
2602+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
2603+
2604+
@check_opset_min_version(11, "NonMaxSuppressionV4")
2605+
def test_non_max_suppression_v4_no_padding(self):
2606+
box_num = 10
2607+
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
2608+
scores_val = np.random.random_sample([box_num]).astype(np.float32)
2609+
2610+
def func(boxes, scores):
2611+
ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num / 2),
2612+
pad_to_max_output_size=True)
2613+
return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
2614+
2615+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
2616+
2617+
@check_opset_min_version(11, "NonMaxSuppressionV5")
2618+
def test_non_max_suppression_v5(self):
2619+
box_num = 10
2620+
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
2621+
scores_val = np.random.random_sample([box_num]).astype(np.float32)
2622+
2623+
def func(boxes, scores):
2624+
ret1, ret2 = tf.image.non_max_suppression_with_scores(boxes, scores, max_output_size=int(box_num / 2),
2625+
soft_nms_sigma=0.0)
2626+
return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
2627+
25872628
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
25882629

2589-
def _conv1d_test(self, x_val, w, stride=None, padding="VALID", rtol=1e-07):
2630+
def _conv1d_test(self, x_val, w, stride=None, padding="VALID", rtol=1e-07):
25902631
if stride is None:
25912632
stride = 1
25922633
def func(x):

tf2onnx/onnx_opset/tensor.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,7 +1452,8 @@ def version_10(cls, ctx, node, **kwargs):
14521452
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
14531453

14541454

1455-
@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
1455+
@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3", "NonMaxSuppressionV4", "NonMaxSuppressionV5"],
1456+
onnx_op="NonMaxSuppression")
14561457
class NonMaxSuppression:
14571458
@classmethod
14581459
def version_10(cls, ctx, node, **kwargs):
@@ -1464,18 +1465,40 @@ def version_10(cls, ctx, node, **kwargs):
14641465
# onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
14651466
# while tf's output is [num_selected_boxes]
14661467
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
1467-
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
1468+
input_score = ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
14681469
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
14691470
# replace original node with nonmaxsurppress + slice + squeeze + cast
1470-
dtypes = [ctx.get_dtype(node.output[0])]
1471-
shapes = [ctx.get_shape(node.output[0])]
1471+
dtypes = [[ctx.get_dtype(output)] for output in node.output]
1472+
shapes = [[ctx.get_shape(output)] for output in node.output]
1473+
max_output_size = node.input[2]
14721474
ctx.remove_node(node.name)
1473-
new_nonmaxsurppress = ctx.make_node(node.type, node.input).output[0]
1475+
new_nonmaxsurppress = ctx.make_node(node.type, node.input[: 5]).output[0]
14741476
slice_op = GraphBuilder(ctx).make_slice({"data": new_nonmaxsurppress,
14751477
"axes": [1], "ends": [3], "starts": [2]})
14761478
squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
1477-
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
1478-
name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)
1479+
if len(node.input) > 5: # V5
1480+
logger.warning("NonMaxSuppressionV5 only partially supported, soft_nms_sigma must be 0.0")
1481+
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
1482+
outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0])
1483+
ctx.make_node("Gather", inputs=[input_score.input[0], squeeze_op.output[0]],
1484+
outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
1485+
elif "pad_to_max_output_size" in node.attr: # V4
1486+
shape_op = ctx.make_node("Shape", inputs=[squeeze_op.output[0]])
1487+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64))
1488+
sub_op = ctx.make_node("Sub", inputs=[max_output_size, shape_op.output[0]])
1489+
raw_pad = ctx.make_node("Concat", inputs=[const_zero.output[0], sub_op.output[0]], attr={'axis': 0})
1490+
raw_pad_float = ctx.make_node("Cast", inputs=[raw_pad.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
1491+
relu_op = ctx.make_node("Relu", inputs=[raw_pad_float.output[0]])
1492+
pad_val = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64})
1493+
pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]])
1494+
ctx.make_node("Cast", inputs=pad_op.output, name="cast_A", attr={"to": onnx_pb.TensorProto.INT32},
1495+
outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0])
1496+
reduce_op = ctx.make_node("ReduceSum", inputs=shape_op.output, attr={"axes": [0], "keepdims": 0})
1497+
ctx.make_node("Cast", inputs=[reduce_op.output[0]], name="cast_B", attr={"to": onnx_pb.TensorProto.INT32},
1498+
outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
1499+
else:
1500+
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
1501+
name=node.name, outputs=node.output, dtypes=dtypes[0], shapes=shapes[0])
14791502

14801503
@classmethod
14811504
def version_11(cls, ctx, node, **kwargs):

0 commit comments

Comments (0)