
Commit 841be50

Inference output as 2 tensors: [b, n, 1, 4] and [b, n, cls]
1 parent 56227b1 commit 841be50

8 files changed: +43 additions, -32 deletions


README.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -109,12 +109,13 @@ you can use darknet2pytorch to convert it yourself, or download my converted mod
 
 - Inference output
 
-    Inference output is of shape `[batch, num_boxes, 4 + num_classes]` in which `[batch, num_boxes, 4]` is x_center, y_center, width, height of bounding boxes, and `[batch, num_boxes, num_classes]` is confidences of bounding box for all classes.
+    There are 2 inference outputs.
+    - One is locations of bounding boxes, its shape is `[batch, num_boxes, 1, 4]` which represents x1, y1, x2, y2 of each bounding box.
+    - The other one is scores of bounding boxes which is of shape `[batch, num_boxes, num_classes]` indicating scores of all classes for each bounding box.
 
     Until now, still a small piece of post-processing including NMS is required. We are trying to minimize time and complexity of post-processing.
 
 
-
 # 3. Darknet2ONNX (Evolving)
 
 - **This script is to convert the official pretrained darknet model into ONNX**
```
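A minimal sketch (not part of the commit) of how the two output tensors can be consumed downstream; the shapes, the 0.4 threshold, and the random scores are placeholders:

```python
import numpy as np

# Placeholder shapes for illustration: batch = 1, num_boxes = 3, num_classes = 80.
boxes = np.zeros((1, 3, 1, 4), dtype=np.float32)     # x1, y1, x2, y2 per box
confs = np.random.rand(1, 3, 80).astype(np.float32)  # class scores per box

boxes_xyxy = boxes[:, :, 0, :]   # squeeze the singleton axis -> [batch, num_boxes, 4]
max_conf = confs.max(axis=2)     # best score per box
max_id = confs.argmax(axis=2)    # best class index per box

keep = max_conf[0] > 0.4         # confidence filter; NMS would follow
print(boxes_xyxy[0][keep].shape, max_id[0][keep])
```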

demo_darknet2onnx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -43,7 +43,7 @@ def detect(session, image_src):
 
     outputs = session.run(None, {input_name: img_in})
 
-    boxes = post_processing(img_in, 0.4, 0.6, outputs[0])
+    boxes = post_processing(img_in, 0.4, 0.6, outputs)
 
     num_classes = 80
     if num_classes == 20:
```
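Because the graph now has two outputs, `session.run` returns a list of two arrays. A self-contained sketch, assuming an exported model file with a hypothetical name and a dummy input:

```python
import numpy as np
import onnxruntime

# Hypothetical model path and input size, for illustration only.
session = onnxruntime.InferenceSession("yolov4_1_3_608_608.onnx")
input_name = session.get_inputs()[0].name
img_in = np.random.rand(1, 3, 608, 608).astype(np.float32)

# Outputs come back in graph order: boxes first, then confs.
boxes, confs = session.run(None, {input_name: img_in})
print(boxes.shape, confs.shape)   # e.g. (1, N, 1, 4) and (1, N, num_classes)
```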

demo_pytorch2onnx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@ def transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W
                       export_params=True,
                       opset_version=11,
                       do_constant_folding=True,
-                      input_names=['input'], output_names=['output'],
+                      input_names=['input'], output_names=['boxes', 'confs'],
                       dynamic_axes=None)
 
     print('Onnx model exporting done')
```
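A standalone sketch of exporting a model that returns a (boxes, confs) tuple with two named ONNX outputs; the tiny module, its fixed box count, and the file name are placeholders, not the repo's detector:

```python
import torch

class TwoOutputNet(torch.nn.Module):
    """Placeholder module returning a (boxes, confs) tuple like the detector."""
    def forward(self, x):
        scale = x.mean()
        boxes = scale * torch.ones(x.shape[0], 10, 1, 4)   # [batch, num_boxes, 1, 4]
        confs = scale * torch.ones(x.shape[0], 10, 80)     # [batch, num_boxes, num_classes]
        return boxes, confs

model = TwoOutputNet()
dummy = torch.randn(1, 3, 608, 608)

# Each element of the returned tuple gets its own name in the ONNX graph.
torch.onnx.export(model, dummy, "two_outputs.onnx",
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names=['input'], output_names=['boxes', 'confs'],
                  dynamic_axes=None)
```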

demo_trt.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -162,17 +162,16 @@ def detect(context, buffers, image_src, image_size, num_classes):
 
     print('Len of outputs: ', len(trt_outputs))
 
-    trt_output = trt_outputs[0].reshape(1, -1, 4 + num_classes)
+    trt_outputs[0] = trt_outputs[0].reshape(1, -1, 1, 4)
+    trt_outputs[1] = trt_outputs[1].reshape(1, -1, num_classes)
 
     tb = time.time()
 
-    print(trt_output.shape)
-
     print('-----------------------------------')
     print(' TRT inference time: %f' % (tb - ta))
     print('-----------------------------------')
 
-    boxes = post_processing(img_in, 0.4, 0.6, trt_output)
+    boxes = post_processing(img_in, 0.4, 0.6, trt_outputs)
 
     return boxes
 
```
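TensorRT host buffers come back as flat float arrays, so each one is reshaped into the tensor layout it represents. A small numpy sketch of just that step, with the box count and class count chosen arbitrarily:

```python
import numpy as np

num_boxes, num_classes = 22743, 80

# Stand-ins for the two flat host buffers copied back from the GPU.
trt_outputs = [
    np.zeros(1 * num_boxes * 1 * 4, dtype=np.float32),
    np.zeros(1 * num_boxes * num_classes, dtype=np.float32),
]

# Same reshape as in demo_trt.py; batch size is fixed to 1 here.
trt_outputs[0] = trt_outputs[0].reshape(1, -1, 1, 4)
trt_outputs[1] = trt_outputs[1].reshape(1, -1, num_classes)
print(trt_outputs[0].shape, trt_outputs[1].shape)
```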

tool/darknet2onnx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
                       export_params=True,
                       opset_version=11,
                       do_constant_folding=True,
-                      input_names=['input'], output_names=['output'],
+                      input_names=['input'], output_names=['boxes', 'confs'],
                       dynamic_axes=None)
 
     print('Onnx model exporting done')
```
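One way to check that the exported graph really carries the two renamed outputs is to inspect it with the onnx package; a short sketch assuming a hypothetical output file name:

```python
import onnx

# Hypothetical file name; use whatever transform_to_onnx actually wrote out.
model = onnx.load("yolov4_1_3_608_608.onnx")
print([out.name for out in model.graph.output])   # expected: ['boxes', 'confs']
```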

tool/torch_utils.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -56,14 +56,12 @@ def get_region_boxes(boxes_and_confs):
         boxes_list.append(item[0])
         confs_list.append(item[1])
 
-    # boxes: [batch, num1 + num2 + num3, 4]
+    # boxes: [batch, num1 + num2 + num3, 1, 4]
     # confs: [batch, num1 + num2 + num3, num_classes]
     boxes = torch.cat(boxes_list, dim=1)
     confs = torch.cat(confs_list, dim=1)
-
-    output = torch.cat((boxes, confs), dim=2)
 
-    return output
+    return [boxes, confs]
 
 
 def convert2cpu(gpu_matrix):
```
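Each YOLO head contributes a (boxes, confs) pair, and the pairs are concatenated along the box dimension instead of being fused into one tensor. A toy sketch with made-up head sizes:

```python
import torch

num_classes = 80
# Three YOLO heads with made-up box counts per head.
boxes_and_confs = [
    (torch.zeros(1, n, 1, 4), torch.zeros(1, n, num_classes))
    for n in (4332, 1083, 300)
]

boxes_list = [item[0] for item in boxes_and_confs]
confs_list = [item[1] for item in boxes_and_confs]

boxes = torch.cat(boxes_list, dim=1)   # [1, 4332 + 1083 + 300, 1, 4]
confs = torch.cat(confs_list, dim=1)   # [1, 4332 + 1083 + 300, num_classes]
print(boxes.shape, confs.shape)
```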

tool/utils.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -62,8 +62,8 @@ def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
     # print(boxes.shape)
     x1 = boxes[:, 0]
     y1 = boxes[:, 1]
-    x2 = boxes[:, 0] + boxes[:, 2]
-    y2 = boxes[:, 1] + boxes[:, 3]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
 
     areas = (x2 - x1) * (y2 - y1)
     order = confs.argsort()[::-1]
@@ -113,10 +113,10 @@ def get_color(c, x, max_val):
     height = img.shape[0]
     for i in range(len(boxes)):
         box = boxes[i]
-        x1 = int((box[0] - box[2] / 2.0) * width)
-        y1 = int((box[1] - box[3] / 2.0) * height)
-        x2 = int((box[0] + box[2] / 2.0) * width)
-        y2 = int((box[1] + box[3] / 2.0) * height)
+        x1 = int(box[0] * width)
+        y1 = int(box[1] * height)
+        x2 = int(box[2] * width)
+        y2 = int(box[3] * height)
 
         if color:
             rgb = color
@@ -171,16 +171,19 @@ def post_processing(img, conf_thresh, nms_thresh, output):
     # strides = [8, 16, 32]
     # anchor_step = len(anchors) // num_anchors
 
+    # [batch, num, 1, 4]
+    box_array = output[0]
+    # [batch, num, num_classes]
+    confs = output[1]
+
     t1 = time.time()
 
-    if type(output).__name__ != 'ndarray':
-        output = output.cpu().detach().numpy()
+    if type(box_array).__name__ != 'ndarray':
+        box_array = box_array.cpu().detach().numpy()
+        confs = confs.cpu().detach().numpy()
 
     # [batch, num, 4]
-    box_array = output[:, :, :4]
-
-    # [batch, num, num_classes]
-    confs = output[:, :, 4:]
+    box_array = box_array[:, :, 0]
 
     # [batch, num, num_classes] --> [batch, num]
     max_conf = np.max(confs, axis=2)
```
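With boxes already in corner form, NMS can use x1, y1, x2, y2 directly instead of converting from center/size. A minimal IoU sketch in that format (a standalone helper, not the repo's exact nms_cpu):

```python
import numpy as np

def iou_xyxy(box, boxes):
    # box: [4], boxes: [N, 4], both as x1, y1, x2, y2 in the same coordinate scale.
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

print(iou_xyxy(np.array([0.1, 0.1, 0.5, 0.5]),
               np.array([[0.1, 0.1, 0.5, 0.5], [0.3, 0.3, 0.7, 0.7]])))
```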

tool/yolo_layer.py

Lines changed: 17 additions & 7 deletions
```diff
@@ -94,16 +94,20 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
     print(anchor_tensor.size())
     bwh *= anchor_tensor
 
-    # Shape: [batch, num_anchors, 4, H * W] --> [batch, num_anchors * H * W, 4]
-    boxes = torch.cat((bxy, bwh), dim=2).permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, 4)
+    bx1y1 = bxy - bwh * 0.5
+    bx2y2 = bxy + bwh
+
+    # Shape: [batch, num_anchors, 4, H * W] --> [batch, num_anchors * H * W, 1, 4]
+    boxes = torch.cat((bx1y1, bx2y2), dim=2).permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, 1, 4)
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
 
     print(normal_tensor.size())
     boxes *= normal_tensor
 
     det_confs = det_confs.view(batch, num_anchors * H * W, 1)
     confs = cls_confs * det_confs
 
-    # boxes: [batch, num_anchors * H * W, 4]
+    # boxes: [batch, num_anchors * H * W, 1, 4]
     # confs: [batch, num_anchors * H * W, num_classes]
 
     return boxes, confs
@@ -231,17 +235,23 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
     bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
     bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
 
-    # Shape: [batch, num_anchors * h * w, 4]
-    boxes = torch.cat((bx, by, bw, bh), dim=2).view(batch, num_anchors * H * W, 4)
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_anchors * H * W, 1, 4)
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
 
-    # boxes: [batch, num_anchors * H * W, num_classes, 4]
+    # boxes: [batch, num_anchors * H * W, 1, 4]
     # cls_confs: [batch, num_anchors * H * W, num_classes]
     # det_confs: [batch, num_anchors * H * W]
 
     det_confs = det_confs.view(batch, num_anchors * H * W, 1)
     confs = cls_confs * det_confs
 
-    # boxes: [batch, num_anchors * H * W, 4]
+    # boxes: [batch, num_anchors * H * W, 1, 4]
     # confs: [batch, num_anchors * H * W, num_classes]
 
     return boxes, confs
```
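The core of this change, isolated as a sketch: center/size predictions become corner boxes and gain a singleton axis so the final layout is [batch, N, 1, 4]. Shapes here are arbitrary and the tensors are random stand-ins for the decoded predictions:

```python
import torch

batch, n = 1, 6
# Made-up center-format predictions, each of shape [batch, n, 1].
bx, by = torch.rand(batch, n, 1), torch.rand(batch, n, 1)
bw, bh = torch.rand(batch, n, 1), torch.rand(batch, n, 1)

# Center/size -> corners, as in yolo_forward.
bx1 = bx - bw * 0.5
by1 = by - bh * 0.5
bx2 = bx1 + bw
by2 = by1 + bh

boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, n, 1, 4)
print(boxes.shape)   # torch.Size([1, 6, 1, 4])
```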
