Skip to content

Commit 675f6ef

Browse files
committed
1. Class confidences should use sigmoid instead of softmax. 2. Run NMS separately for each class.
1 parent 9904574 commit 675f6ef

File tree

4 files changed

+46
-12
lines changed

4 files changed

+46
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ runs
1616
log
1717

1818
*.jpg
19+
*.json
1920
data/outcome

run_coco.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
from tool.torch_utils import *
5+
from models import *
6+
7+
import sys
8+
import cv2
9+
from tool.utils import load_class_names, plot_boxes_cv2
10+
11+
12+
n_classes = 80
13+
cocoImageListFileName = "/home/erics/MS_COCO/val2017.txt"
14+
cocoClassIDFileName = "/home/erics/yolo_cpp_standalone/data/categories.txt"
15+
cocoClassNamesFileName = "/home/erics/yolo_cpp_standalone/data/coco.names"
16+
17+
model = Yolov4(yolov4conv137weight=None, n_classes=n_classes, inference=True)
18+
19+
pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda'))
20+
model.load_state_dict(pretrained_dict)
21+
22+
23+
boxes = do_detect(model, sized, 0.4, 0.6, use_cuda)

tool/utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def post_processing(img, conf_thresh, nms_thresh, output):
182182
box_array = box_array.cpu().detach().numpy()
183183
confs = confs.cpu().detach().numpy()
184184

185+
num_classes = confs.shape[2]
186+
185187
# [batch, num, 4]
186188
box_array = box_array[:, :, 0]
187189

@@ -199,16 +201,24 @@ def post_processing(img, conf_thresh, nms_thresh, output):
199201
l_max_conf = max_conf[i, argwhere]
200202
l_max_id = max_id[i, argwhere]
201203

202-
keep = nms_cpu(l_box_array, l_max_conf, nms_thresh)
203-
204204
bboxes = []
205-
if (keep.size > 0):
206-
l_box_array = l_box_array[keep, :]
207-
l_max_conf = l_max_conf[keep]
208-
l_max_id = l_max_id[keep]
209-
210-
for j in range(l_box_array.shape[0]):
211-
bboxes.append([l_box_array[j, 0], l_box_array[j, 1], l_box_array[j, 2], l_box_array[j, 3], l_max_conf[j], l_max_conf[j], l_max_id[j]])
205+
# nms for each class
206+
for j in range(num_classes):
207+
208+
cls_argwhere = l_max_id == j
209+
ll_box_array = l_box_array[cls_argwhere, :]
210+
ll_max_conf = l_max_conf[cls_argwhere]
211+
ll_max_id = l_max_id[cls_argwhere]
212+
213+
keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
214+
215+
if (keep.size > 0):
216+
ll_box_array = ll_box_array[keep, :]
217+
ll_max_conf = ll_max_conf[keep]
218+
ll_max_id = ll_max_id[keep]
219+
220+
for k in range(ll_box_array.shape[0]):
221+
bboxes.append([ll_box_array[k, 0], ll_box_array[k, 1], ll_box_array[k, 2], ll_box_array[k, 3], ll_max_conf[k], ll_max_conf[k], ll_max_id[k]])
212222

213223
bboxes_batch.append(bboxes)
214224

tool/yolo_layer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tool.torch_utils import *
44

55

6-
def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1,
6+
def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
77
validation=False):
88
# Output would be invalid if it does not satisfy this assert
99
# assert (output.size(1) == (5 + num_classes) * num_anchors)
@@ -80,7 +80,7 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
8080
bxy = torch.sigmoid(bxy)
8181
bwh = torch.exp(bwh)
8282
det_confs = torch.sigmoid(det_confs)
83-
cls_confs = torch.nn.Softmax(dim=2)(cls_confs)
83+
cls_confs = torch.sigmoid(cls_confs)
8484

8585
# Shape: [batch, num_anchors, 2, H * W]
8686
bxy = bxy.view(batch, num_anchors, 2, H * W)
@@ -165,7 +165,7 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
165165
bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
166166
bwh = torch.exp(bwh)
167167
det_confs = torch.sigmoid(det_confs)
168-
cls_confs = torch.nn.Softmax(dim=2)(cls_confs)
168+
cls_confs = torch.sigmoid(cls_confs)
169169

170170
# Prepare C-x, C-y, P-w, P-h (None of them are torch related)
171171
grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0), axis=0)

0 commit comments

Comments
 (0)