Skip to content

Commit 5df948f

Browse files
authored
Merge pull request #14 from kaanakan/dev
empty detections bugfix
2 parents 6af4502 + 2007809 commit 5df948f

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

confusion_matrix.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def box_area(box):
3131

3232

3333
class ConfusionMatrix:
34-
def __init__(self, num_classes, CONF_THRESHOLD=0.3, IOU_THRESHOLD=0.5):
34+
def __init__(self, num_classes: int, CONF_THRESHOLD=0.3, IOU_THRESHOLD=0.5):
3535
self.matrix = np.zeros((num_classes + 1, num_classes + 1))
3636
self.num_classes = num_classes
3737
self.CONF_THRESHOLD = CONF_THRESHOLD
3838
self.IOU_THRESHOLD = IOU_THRESHOLD
3939

40-
def process_batch(self, detections, labels):
40+
def process_batch(self, detections, labels: np.ndarray):
4141
"""
4242
Return intersection-over-union (Jaccard index) of boxes.
4343
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
@@ -47,8 +47,17 @@ def process_batch(self, detections, labels):
4747
Returns:
4848
None, updates confusion matrix accordingly
4949
"""
50-
detections = detections[detections[:, 4] > self.CONF_THRESHOLD]
5150
gt_classes = labels[:, 0].astype(np.int16)
51+
52+
try:
53+
detections = detections[detections[:, 4] > self.CONF_THRESHOLD]
54+
except IndexError or TypeError:
55+
# detections are empty, end of process
56+
for i, label in enumerate(labels):
57+
gt_class = gt_classes[i]
58+
self.matrix[self.num_classes, gt_class] += 1
59+
return
60+
5261
detection_classes = detections[:, 5].astype(np.int16)
5362

5463
all_ious = box_iou_calc(labels[:, 1:], detections[:, :4])

0 commit comments

Comments
 (0)