@@ -19,66 +19,62 @@ def box_iou_calc(boxes1, boxes2):
1919 def box_area (box ):
2020 # box = 4xn
2121 return (box [2 ] - box [0 ]) * (box [3 ] - box [1 ])
22-
2322
2423 area1 = box_area (boxes1 .T )
2524 area2 = box_area (boxes2 .T )
2625
2726 lt = np .maximum (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
2827 rb = np .minimum (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
2928
30- inter = np .prod (np .clip (rb - lt , a_min = 0 , a_max = None ), 2 )
29+ inter = np .prod (np .clip (rb - lt , a_min = 0 , a_max = None ), 2 )
3130 return inter / (area1 [:, None ] + area2 - inter ) # iou = inter / (area1 + area2 - inter)
3231
3332
3433class ConfusionMatrix :
35- def __init__ (self , num_classes , CONF_THRESHOLD = 0.3 , IOU_THRESHOLD = 0.5 ):
34+ def __init__ (self , num_classes , CONF_THRESHOLD = 0.3 , IOU_THRESHOLD = 0.5 ):
3635 self .matrix = np .zeros ((num_classes + 1 , num_classes + 1 ))
3736 self .num_classes = num_classes
3837 self .CONF_THRESHOLD = CONF_THRESHOLD
3938 self .IOU_THRESHOLD = IOU_THRESHOLD
40-
39+
4140 def process_batch (self , detections , labels ):
42- '''
41+ """
4342 Return intersection-over-union (Jaccard index) of boxes.
4443 Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
4544 Arguments:
4645 detections (Array[N, 6]), x1, y1, x2, y2, conf, class
4746 labels (Array[M, 5]), class, x1, y1, x2, y2
4847 Returns:
4948 None, updates confusion matrix accordingly
50- '''
49+ """
5150 detections = detections [detections [:, 4 ] > self .CONF_THRESHOLD ]
5251 gt_classes = labels [:, 0 ].astype (np .int16 )
5352 detection_classes = detections [:, 5 ].astype (np .int16 )
5453
5554 all_ious = box_iou_calc (labels [:, 1 :], detections [:, :4 ])
5655 want_idx = np .where (all_ious > self .IOU_THRESHOLD )
5756
58- all_matches = []
59- for i in range (want_idx [0 ].shape [0 ]):
60- all_matches .append ([want_idx [0 ][i ], want_idx [1 ][i ], all_ious [want_idx [0 ][i ], want_idx [1 ][i ]]])
61-
57+ all_matches = [[want_idx [0 ][i ], want_idx [1 ][i ], all_ious [want_idx [0 ][i ], want_idx [1 ][i ]]]
58+ for i in range (want_idx [0 ].shape [0 ])]
59+
6260 all_matches = np .array (all_matches )
63- if all_matches .shape [0 ] > 0 : # if there is match
61+ if all_matches .shape [0 ] > 0 : # if there is match
6462 all_matches = all_matches [all_matches [:, 2 ].argsort ()[::- 1 ]]
6563
66- all_matches = all_matches [np .unique (all_matches [:, 1 ], return_index = True )[1 ]]
64+ all_matches = all_matches [np .unique (all_matches [:, 1 ], return_index = True )[1 ]]
6765
6866 all_matches = all_matches [all_matches [:, 2 ].argsort ()[::- 1 ]]
6967
70- all_matches = all_matches [np .unique (all_matches [:, 0 ], return_index = True )[1 ]]
71-
68+ all_matches = all_matches [np .unique (all_matches [:, 0 ], return_index = True )[1 ]]
7269
7370 for i , label in enumerate (labels ):
71+ gt_class = gt_classes [i ]
7472 if all_matches .shape [0 ] > 0 and all_matches [all_matches [:, 0 ] == i ].shape [0 ] == 1 :
75- gt_class = gt_classes [i ]
7673 detection_class = detection_classes [int (all_matches [all_matches [:, 0 ] == i , 1 ][0 ])]
77- self .matrix [( gt_class ), detection_class ] += 1
74+ self .matrix [detection_class , gt_class ] += 1
7875 else :
79- gt_class = gt_classes [i ]
80- self .matrix [self .num_classes , (gt_class )] += 1
81-
76+ self .matrix [self .num_classes , gt_class ] += 1
77+
8278 for i , detection in enumerate (detections ):
8379 if all_matches .shape [0 ] and all_matches [all_matches [:, 1 ] == i ].shape [0 ] == 0 :
8480 detection_class = detection_classes [i ]
@@ -90,4 +86,3 @@ def return_matrix(self):
9086 def print_matrix (self ):
9187 for i in range (self .num_classes + 1 ):
9288 print (' ' .join (map (str , self .matrix [i ])))
93-
0 commit comments