1616
1717class Detector (object ):
1818
19- def __init__ (self , model , image_size = Config .IMAGE_SIZE , keep = 200 ):
19+ def __init__ (self , model , image_size = Config .IMAGE_SIZE , threshold = Config . PREDICTION_THRESHOLD ):
2020 checkpoint = torch .load (seek_model (model ))
2121 self .model = Net ().to (device )
2222 self .model .load_state_dict (checkpoint ['state_dict' ], strict = True )
23- self .keep = keep
23+ self .threshold = threshold
2424 self .image_size = image_size
2525
2626 def infer (self , image ):
@@ -39,10 +39,11 @@ def infer(self, image):
3939
4040 # get sorted indices by score
4141 diff = predictions [:, 5 ] - predictions [:, 4 ]
42- scores , indices = torch .sort (diff , descending = True )
43- # sort and slice predictions
44- predictions = predictions [indices ][:self .keep ]
45- scores = scores [:self .keep ]
42+ scores , sorted_indices = torch .sort (diff , descending = True )
43+ valid_indices = scores > self .threshold
44+ scores = scores [valid_indices ]
45+
46+ predictions = predictions [sorted_indices ][valid_indices ]
4647 # generate anchors then sort and slice
4748 anchor_configs = (
4849 Config .ANCHOR_STRIDE ,
@@ -52,7 +53,7 @@ def infer(self, image):
5253 anchors = change_coordinate (np .vstack (
5354 list (map (lambda x : np .array (x ), generate_anchors (* anchor_configs )))
5455 ))
55- anchors = torch .tensor (anchors [ indices ][: self . keep ]) .float ().to (device )
56+ anchors = torch .tensor (anchors )[ sorted_indices ][ valid_indices ] .float ().to (device )
5657
5758 x = (predictions [:, 0 ] * anchors [:, 2 ] + anchors [:, 0 ]) * scale [1 ]
5859 y = (predictions [:, 1 ] * anchors [:, 3 ] + anchors [:, 1 ]) * scale [0 ]
@@ -62,10 +63,11 @@ def infer(self, image):
6263 bounding_boxes = torch .stack ((x , y , w , h ), dim = 1 ).cpu ().data .numpy ()
6364 bounding_boxes = change_coordinate_inv (bounding_boxes )
6465 scores = scores .cpu ().data .numpy ()
65- bboxes_scores = np .hstack ((bounding_boxes ,np .array ([scores ]).T ))
66- # TODO: do non-maximum suppression for bounding_boxes here
66+ bboxes_scores = np .hstack ((bounding_boxes , np .array ([scores ]).T ))
67+
68+ # nms
6769 keep = nms (bboxes_scores )
68-
70+
6971 return bounding_boxes [keep ]
7072
7173def main (args ):
0 commit comments