
Commit ca4087e

support NonMaxSuppression

1 parent 022b718

File tree

3 files changed: +47 -4 lines changed

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
@@ -2117,5 +2117,19 @@ def test_isinf(self):
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
         tf.reset_default_graph()
 
+    @check_opset_min_version(10, "NonMaxSuppression")
+    def test_non_max_suppression(self):
+        box_num = 10
+        boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
+        scores_val = np.random.random_sample([box_num]).astype(np.float32)
+        boxes = tf.placeholder(tf.float32, shape=[None, 4], name=_TFINPUT)
+        scores = tf.placeholder(tf.float32, shape=[None], name=_TFINPUT1)
+        res1 = tf.image.non_max_suppression(boxes, scores, max_output_size=int(box_num / 2))
+        res2 = tf.image.non_max_suppression(boxes, scores, max_output_size=0)
+        _ = tf.identity(res1, name=_TFOUTPUT)
+        _ = tf.identity(res2, name=_TFOUTPUT1)
+        self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
+
+
 if __name__ == '__main__':
     unittest_main()
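
For orientation, a minimal standalone sketch (plain TF 1.x, outside the test harness; the placeholder names are illustrative rather than the suite's _TFINPUT/_TFOUTPUT constants) of the op the new test exercises:

import numpy as np
import tensorflow as tf

box_num = 10
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
scores_val = np.random.random_sample([box_num]).astype(np.float32)

boxes = tf.placeholder(tf.float32, shape=[None, 4])
scores = tf.placeholder(tf.float32, shape=[None])
# Returns at most max_output_size int32 indices into boxes, ordered by
# descending score, with boxes overlapping a higher-scoring box suppressed.
selected = tf.image.non_max_suppression(boxes, scores, max_output_size=box_num // 2)

with tf.Session() as sess:
    idx = sess.run(selected, {boxes: boxes_val, scores: scores_val})
    print(idx.dtype, idx.shape)  # int32, shape at most (5,)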

tf2onnx/handler.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 
 # pylint: disable=unused-argument,missing-docstring,invalid-name
 
+
 class tf_op:
     """Class to implement the decorator to register handlers that map tf to onnx."""
 

tf2onnx/onnx_opset/tensor.py

Lines changed: 32 additions & 4 deletions
@@ -24,6 +24,7 @@
 
 # pylint: disable=unused-argument,missing-docstring,unused-variable
 
+
 def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]
@@ -55,6 +56,7 @@ def _wrap_concat_with_cast(ctx, node):
     ctx.set_dtype(output_cast.output[0], dtype)
     ctx.copy_shape(output_name, output_cast.output[0])
 
+
 @tf_op(["Size", "Flatten", "Dropout"])
 class DirectOp:
     @classmethod
@@ -296,8 +298,6 @@ def version_4(cls, ctx, node, **kwargs):
         ctx.remove_input(node, node.input[2])
         node.set_attr("axis", axis)
 
-INT64_MAX = np.iinfo(np.int64).max
-
 
 def _make_gathernd_inner_loop(ctx, params, index, dtype):
     """create the inner loop for GatherNd."""
@@ -343,7 +343,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
     # reshape indices into [sum(indices[:-1]), indices[-1]]
     indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
     indices_size = ctx.make_node("Size", [indices])
-    attr = {"axes": [0], "ends": [INT64_MAX], "starts": [-1]}
+    attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]}
     inputs_map = {"data": indices_shape.output[0], **attr}
     inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
     outter_shape = ctx.make_node("Div",
@@ -401,7 +401,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
                                        [inner_loop_shape.output[0], one_const.output[0]],
                                        attr={"axis": 0},
                                        dtypes=[TensorProto.INT64])
-    attr = {"axes": [0], "ends": [INT64_MAX], "starts": [1]}
+    attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]}
     inputs_map = {"data": inner_loop_shape_.output[0], **attr}
     output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
     attr = {"axes": [0], "ends": [-1], "starts": [0]}
@@ -932,3 +932,31 @@ def version_10(cls, ctx, node, **kwargs):
         utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
         if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
             raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
+
+
+@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
+class NonMaxSuppression:
+    @classmethod
+    def version_10(cls, ctx, node, **kwargs):
+        # int32 = NonMaxSuppressionV2(T boxes, T scores, int32 max_output_size, T iou_threshold, T score_threshold)
+        # int64 = NonMaxSuppression(T boxes, T scores, int64 max_output_size, T iou_threshold, T score_threshold)
+        # T means float32 here; the last 3 params are optional.
+        # tf boxes is 2D ([boxes_num, 4]) while onnx is 3D ([num_batches, boxes_num, 4]).
+        # tf scores is 1D ([boxes_num]) while onnx is 2D ([num_batches, num_classes, boxes_num]).
+        # onnx output is [num_selected_boxes, 3]; the last dim means [batch_index, class_index, box_index],
+        # while tf's output is [num_selected_boxes].
+        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
+        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
+        ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
+        # replace the original node with NonMaxSuppression + Slice + Squeeze + Cast
+        dtypes = [ctx.get_dtype(node.output[0])]
+        shapes = [ctx.get_shape(node.output[0])]
+        ctx.remove_node(node.name)
+        new_nonmaxsuppression = ctx.make_node(node.type, node.input).output[0]
+        attr = {"axes": [1], "ends": [3], "starts": [2]}
+        inputs_map = {"data": new_nonmaxsuppression, **attr}
+        slice_op = GraphBuilder(ctx).make_slice(inputs_map)
+        squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
+        ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
+                      name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)
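
To make the handler's shape juggling concrete, a small numpy sketch (illustrative only) of what the emitted Slice + Squeeze + Cast chain does to ONNX NonMaxSuppression output:

import numpy as np

# ONNX NonMaxSuppression returns [num_selected_boxes, 3] int64 rows of
# [batch_index, class_index, box_index]; after the Unsqueezes above there is
# exactly one batch and one class, so only the last column carries information.
onnx_selected = np.array([[0, 0, 7],
                          [0, 0, 2],
                          [0, 0, 9]], dtype=np.int64)

# Slice(axes=[1], starts=[2], ends=[3]) keeps the box_index column,
box_index = onnx_selected[:, 2:3]      # shape [num_selected_boxes, 1]
# Squeeze(axes=[1]) drops the trailing dimension,
box_index = box_index.squeeze(axis=1)  # shape [num_selected_boxes]
# and Cast restores tf.image.non_max_suppression's int32 output dtype.
tf_selected = box_index.astype(np.int32)
print(tf_selected)  # [7 2 9]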
