Skip to content

Commit c8d7a3b

Browse files
Implement non-standard nms for tflite (#1318)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d5267bd commit c8d7a3b

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/test_tflite_postprocess.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,17 @@ def test_postprocess_model3(self):
4545
def test_postprocess_model4(self):
4646
self._test_postprocess(num_classes=5, num_boxes=99, detections_per_class=2, max_detections=20, extra_class=True)
4747

48-
def _test_postprocess(self, num_classes, num_boxes, detections_per_class, max_detections, extra_class=False):
48+
@requires_tflite("TFLite_Detection_PostProcess")
49+
@check_opset_min_version(11, "Pad")
50+
def test_postprocess_model5(self):
51+
self._test_postprocess(num_classes=1, num_boxes=100, detections_per_class=0,
52+
max_detections=50, use_regular_nms=False)
53+
54+
def _test_postprocess(self, num_classes, num_boxes, detections_per_class,
55+
max_detections, extra_class=False, use_regular_nms=True):
4956
model = self.make_postprocess_model(num_classes=num_classes, detections_per_class=detections_per_class,
50-
max_detections=max_detections, x_scale=11.0, w_scale=6.0)
57+
max_detections=max_detections, x_scale=11.0, w_scale=6.0,
58+
use_regular_nms=use_regular_nms)
5159

5260
np.random.seed(42)
5361
box_encodings_val = np.random.random_sample([1, num_boxes, 4]).astype(np.float32)

tf2onnx/tflite_handlers/tfl_postprocess.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ def version_11(cls, ctx, node, **kwargs):
6969
score_threshold = np.array(node.get_attr_value('nms_score_threshold'), np.float32)
7070
score_threshold_const = ctx.make_const(utils.make_name('score_threshold'), score_threshold).output[0]
7171

72-
boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64)
72+
if node.get_attr_value('use_regular_nms', False):
73+
boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64)
74+
else:
75+
# When tflite uses FastNMS, detections_per_class is ignored.
76+
logging.warning("NMS node %s uses fast NMS. ONNX will approximate with standard NMS.", node.name)
77+
boxes_per_class = np.array(max_detections, np.int64)
7378
max_boxes_per_class_const = ctx.make_const(utils.make_name('max_boxes_per_class'), boxes_per_class).output[0]
7479

7580
# scores.shape = [batch_dim, classes_num, box_num]

0 commit comments

Comments
 (0)