Skip to content

Commit 2f41ae4

Browse files
author
Berkay Ugur Senocak
committed
bugfix & refactor
1 parent cd235e7 commit 2f41ae4

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

confusion_matrix.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,66 +19,62 @@ def box_iou_calc(boxes1, boxes2):
1919
def box_area(box):
2020
# box = 4xn
2121
return (box[2] - box[0]) * (box[3] - box[1])
22-
2322

2423
area1 = box_area(boxes1.T)
2524
area2 = box_area(boxes2.T)
2625

2726
lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
2827
rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
2928

30-
inter = np.prod(np.clip(rb - lt, a_min = 0, a_max = None), 2)
29+
inter = np.prod(np.clip(rb - lt, a_min=0, a_max=None), 2)
3130
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
3231

3332

3433
class ConfusionMatrix:
35-
def __init__(self, num_classes, CONF_THRESHOLD = 0.3, IOU_THRESHOLD = 0.5):
34+
def __init__(self, num_classes, CONF_THRESHOLD=0.3, IOU_THRESHOLD=0.5):
3635
self.matrix = np.zeros((num_classes + 1, num_classes + 1))
3736
self.num_classes = num_classes
3837
self.CONF_THRESHOLD = CONF_THRESHOLD
3938
self.IOU_THRESHOLD = IOU_THRESHOLD
40-
39+
4140
def process_batch(self, detections, labels):
42-
'''
41+
"""
4342
Return intersection-over-union (Jaccard index) of boxes.
4443
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
4544
Arguments:
4645
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
4746
labels (Array[M, 5]), class, x1, y1, x2, y2
4847
Returns:
4948
None, updates confusion matrix accordingly
50-
'''
49+
"""
5150
detections = detections[detections[:, 4] > self.CONF_THRESHOLD]
5251
gt_classes = labels[:, 0].astype(np.int16)
5352
detection_classes = detections[:, 5].astype(np.int16)
5453

5554
all_ious = box_iou_calc(labels[:, 1:], detections[:, :4])
5655
want_idx = np.where(all_ious > self.IOU_THRESHOLD)
5756

58-
all_matches = []
59-
for i in range(want_idx[0].shape[0]):
60-
all_matches.append([want_idx[0][i], want_idx[1][i], all_ious[want_idx[0][i], want_idx[1][i]]])
61-
57+
all_matches = [[want_idx[0][i], want_idx[1][i], all_ious[want_idx[0][i], want_idx[1][i]]]
58+
for i in range(want_idx[0].shape[0])]
59+
6260
all_matches = np.array(all_matches)
63-
if all_matches.shape[0] > 0: # if there is match
61+
if all_matches.shape[0] > 0: # if there is match
6462
all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
6563

66-
all_matches = all_matches[np.unique(all_matches[:, 1], return_index = True)[1]]
64+
all_matches = all_matches[np.unique(all_matches[:, 1], return_index=True)[1]]
6765

6866
all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
6967

70-
all_matches = all_matches[np.unique(all_matches[:, 0], return_index = True)[1]]
71-
68+
all_matches = all_matches[np.unique(all_matches[:, 0], return_index=True)[1]]
7269

7370
for i, label in enumerate(labels):
71+
gt_class = gt_classes[i]
7472
if all_matches.shape[0] > 0 and all_matches[all_matches[:, 0] == i].shape[0] == 1:
75-
gt_class = gt_classes[i]
7673
detection_class = detection_classes[int(all_matches[all_matches[:, 0] == i, 1][0])]
77-
self.matrix[(gt_class), detection_class] += 1
74+
self.matrix[detection_class, gt_class] += 1
7875
else:
79-
gt_class = gt_classes[i]
80-
self.matrix[self.num_classes, (gt_class)] += 1
81-
76+
self.matrix[self.num_classes, gt_class] += 1
77+
8278
for i, detection in enumerate(detections):
8379
if all_matches.shape[0] and all_matches[all_matches[:, 1] == i].shape[0] == 0:
8480
detection_class = detection_classes[i]
@@ -90,4 +86,3 @@ def return_matrix(self):
9086
def print_matrix(self):
9187
for i in range(self.num_classes + 1):
9288
print(' '.join(map(str, self.matrix[i])))
93-

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
numpy

0 commit comments

Comments
 (0)