Skip to content

Commit 3f62b17

Browse files
authored
Merge pull request #8 from louis-she/add-prediction-threshold
change top-k to threshold
2 parents 59bc710 + 736495d commit 3f62b17

File tree

3 files changed

+50
-32
lines changed

3 files changed

+50
-32
lines changed

config.py.example

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Config(object):
3232
ANCHOR_STRIDE = [8, 16, 32, 64, 128]
3333
ANCHOR_SIZE = [32, 64, 128, 256, 512]
3434
NEG_POS_ANCHOR_NUM_RATIO = 3
35-
36-
# nms threshold
37-
NMS_THRESHOLD = 0.3
35+
36+
# prediction
37+
NMS_THRESHOLD = 0.3
38+
PREDICTION_THRESHOLD = 6

detector.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
class Detector(object):
1818

19-
def __init__(self, model, image_size=Config.IMAGE_SIZE, keep=200):
19+
def __init__(self, model, image_size=Config.IMAGE_SIZE, threshold=Config.PREDICTION_THRESHOLD):
2020
checkpoint = torch.load(seek_model(model))
2121
self.model = Net().to(device)
2222
self.model.load_state_dict(checkpoint['state_dict'], strict=True)
23-
self.keep = keep
23+
self.threshold = threshold
2424
self.image_size = image_size
2525

2626
def infer(self, image):
@@ -39,10 +39,11 @@ def infer(self, image):
3939

4040
# get sorted indices by score
4141
diff = predictions[:, 5] - predictions[:, 4]
42-
scores, indices = torch.sort(diff, descending=True)
43-
# sort and slice predictions
44-
predictions = predictions[indices][:self.keep]
45-
scores = scores[:self.keep]
42+
scores, sorted_indices = torch.sort(diff, descending=True)
43+
valid_indices = scores > self.threshold
44+
scores = scores[valid_indices]
45+
46+
predictions = predictions[sorted_indices][valid_indices]
4647
# generate anchors then sort and slice
4748
anchor_configs = (
4849
Config.ANCHOR_STRIDE,
@@ -52,7 +53,7 @@ def infer(self, image):
5253
anchors = change_coordinate(np.vstack(
5354
list(map(lambda x: np.array(x), generate_anchors(*anchor_configs)))
5455
))
55-
anchors = torch.tensor(anchors[indices][:self.keep]).float().to(device)
56+
anchors = torch.tensor(anchors)[sorted_indices][valid_indices].float().to(device)
5657

5758
x = (predictions[:, 0] * anchors[:, 2] + anchors[:, 0]) * scale[1]
5859
y = (predictions[:, 1] * anchors[:, 3] + anchors[:, 1]) * scale[0]
@@ -62,10 +63,11 @@ def infer(self, image):
6263
bounding_boxes = torch.stack((x, y, w, h), dim=1).cpu().data.numpy()
6364
bounding_boxes = change_coordinate_inv(bounding_boxes)
6465
scores = scores.cpu().data.numpy()
65-
bboxes_scores = np.hstack((bounding_boxes,np.array([scores]).T))
66-
# TODO: do non-maximum suppression for bounding_boxes here
66+
bboxes_scores = np.hstack((bounding_boxes, np.array([scores]).T))
67+
68+
# nms
6769
keep = nms(bboxes_scores)
68-
70+
6971
return bounding_boxes[keep]
7072

7173
def main(args):

inference.ipynb

Lines changed: 34 additions & 19 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)