Skip to content

Commit adb3fff

Browse files
committed
assert soft_nms_sigma == 0
1 parent 1d5b9aa commit adb3fff

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

tests/test_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2608,12 +2608,13 @@ def test_non_max_suppression_v4_no_padding(self):
26082608
scores_val = np.random.random_sample([box_num]).astype(np.float32)
26092609

26102610
def func(boxes, scores):
2611-
ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num / 2),
2612-
pad_to_max_output_size=True)
2611+
ret1, ret2 = tf.image.non_max_suppression_padded(boxes, scores, max_output_size=int(box_num),
2612+
pad_to_max_output_size=False)
26132613
return tf.identity(ret1, name=_TFOUTPUT), tf.identity(ret2, name=_TFOUTPUT1)
26142614

26152615
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
26162616

2617+
@check_tf_min_version("1.15")
26172618
@check_opset_min_version(11, "NonMaxSuppressionV5")
26182619
def test_non_max_suppression_v5(self):
26192620
box_num = 10

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1471,13 +1471,14 @@ def version_10(cls, ctx, node, **kwargs):
14711471
dtypes = [[ctx.get_dtype(output)] for output in node.output]
14721472
shapes = [[ctx.get_shape(output)] for output in node.output]
14731473
max_output_size = node.input[2]
1474+
utils.make_sure(len(node.inputs) <= 5 or int(node.inputs[5].get_tensor_value(False)) == 0,
1475+
"soft_nms_sigma must be 0")
14741476
ctx.remove_node(node.name)
14751477
new_nonmaxsurppress = ctx.make_node(node.type, node.input[: 5]).output[0]
14761478
slice_op = GraphBuilder(ctx).make_slice({"data": new_nonmaxsurppress,
14771479
"axes": [1], "ends": [3], "starts": [2]})
14781480
squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
14791481
if len(node.input) > 5: # V5
1480-
logger.warning("NonMaxSuppressionV5 only parltially supported, soft_nms_sigma must be 0.0")
14811482
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
14821483
outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0])
14831484
ctx.make_node("Gather", inputs=[input_score.input[0], squeeze_op.output[0]],

0 commit comments

Comments
 (0)