@@ -45,9 +45,17 @@ def test_postprocess_model3(self):
45
45
def test_postprocess_model4 (self ):
46
46
self ._test_postprocess (num_classes = 5 , num_boxes = 99 , detections_per_class = 2 , max_detections = 20 , extra_class = True )
47
47
48
- def _test_postprocess (self , num_classes , num_boxes , detections_per_class , max_detections , extra_class = False ):
48
+ @requires_tflite ("TFLite_Detection_PostProcess" )
49
+ @check_opset_min_version (11 , "Pad" )
50
+ def test_postprocess_model5 (self ):
51
+ self ._test_postprocess (num_classes = 1 , num_boxes = 100 , detections_per_class = 0 ,
52
+ max_detections = 50 , use_regular_nms = False )
53
+
54
+ def _test_postprocess (self , num_classes , num_boxes , detections_per_class ,
55
+ max_detections , extra_class = False , use_regular_nms = True ):
49
56
model = self .make_postprocess_model (num_classes = num_classes , detections_per_class = detections_per_class ,
50
- max_detections = max_detections , x_scale = 11.0 , w_scale = 6.0 )
57
+ max_detections = max_detections , x_scale = 11.0 , w_scale = 6.0 ,
58
+ use_regular_nms = use_regular_nms )
51
59
52
60
np .random .seed (42 )
53
61
box_encodings_val = np .random .random_sample ([1 , num_boxes , 4 ]).astype (np .float32 )
0 commit comments