@@ -31,13 +31,13 @@ def box_area(box):
3131
3232
3333class ConfusionMatrix :
34- def __init__ (self , num_classes , CONF_THRESHOLD = 0.3 , IOU_THRESHOLD = 0.5 ):
34+ def __init__ (self , num_classes : int , CONF_THRESHOLD = 0.3 , IOU_THRESHOLD = 0.5 ):
3535 self .matrix = np .zeros ((num_classes + 1 , num_classes + 1 ))
3636 self .num_classes = num_classes
3737 self .CONF_THRESHOLD = CONF_THRESHOLD
3838 self .IOU_THRESHOLD = IOU_THRESHOLD
3939
40- def process_batch (self , detections , labels ):
40+ def process_batch (self , detections , labels : np . ndarray ):
4141 """
4242 Return intersection-over-union (Jaccard index) of boxes.
4343 Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
@@ -47,8 +47,17 @@ def process_batch(self, detections, labels):
4747 Returns:
4848 None, updates confusion matrix accordingly
4949 """
50- detections = detections [detections [:, 4 ] > self .CONF_THRESHOLD ]
5150 gt_classes = labels [:, 0 ].astype (np .int16 )
51+
52+ try :
53+ detections = detections [detections [:, 4 ] > self .CONF_THRESHOLD ]
54+ except IndexError or TypeError :
55+ # detections are empty, end of process
56+ for i , label in enumerate (labels ):
57+ gt_class = gt_classes [i ]
58+ self .matrix [self .num_classes , gt_class ] += 1
59+ return
60+
5261 detection_classes = detections [:, 5 ].astype (np .int16 )
5362
5463 all_ious = box_iou_calc (labels [:, 1 :], detections [:, :4 ])
0 commit comments