55
66
77class DetectionMetricsFactory :
8+ """Factory class for computing detection metrics including precision, recall, AP, and mAP.
9+
10+ :param iou_threshold: IoU threshold for matching predictions to ground truth, defaults to 0.5
11+ :type iou_threshold: float, optional
12+ :param num_classes: Number of classes in the dataset, defaults to None
13+ :type num_classes: Optional[int], optional
14+ """
15+
816 def __init__ (self , iou_threshold : float = 0.5 , num_classes : Optional [int ] = None ):
917 self .iou_threshold = iou_threshold
1018 self .num_classes = num_classes
@@ -15,14 +23,18 @@ def __init__(self, iou_threshold: float = 0.5, num_classes: Optional[int] = None
1523 ) # List of (gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores)
1624
1725 def update (self , gt_boxes , gt_labels , pred_boxes , pred_labels , pred_scores ):
18- """
19- Add a batch of predictions and ground truths.
20-
21- :param gt_boxes: List[ndarray], shape (num_gt, 4)
22- :param gt_labels: List[int]
23- :param pred_boxes: List[ndarray], shape (num_pred, 4)
24- :param pred_labels: List[int]
25- :param pred_scores: List[float]
26+ """Add a batch of predictions and ground truths.
27+
28+ :param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
29+ :type gt_boxes: List[ndarray]
30+ :param gt_labels: Ground truth class labels
31+ :type gt_labels: List[int]
32+ :param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
33+ :type pred_boxes: List[ndarray]
34+ :param pred_labels: Predicted class labels
35+ :type pred_labels: List[int]
36+ :param pred_scores: Prediction confidence scores
37+ :type pred_scores: List[float]
2638 """
2739
2840 # Convert torch tensors to numpy
@@ -74,14 +86,22 @@ def _match_predictions(
7486 pred_scores : List [float ],
7587 iou_threshold : Optional [float ] = None ,
7688 ) -> Dict [int , List [Tuple [float , int ]]]:
77- """
78- Match predictions to ground truth and return per-class TP/FP flags with scores.
79-
80- Args:
81- iou_threshold: If provided, overrides self.iou_threshold
82-
83- Returns:
84- Dict[label_id, List[(score, tp_or_fp)]]
89+ """Match predictions to ground truth and return per-class TP/FP flags with scores.
90+
91+ :param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
92+ :type gt_boxes: np.ndarray
93+ :param gt_labels: Ground truth class labels
94+ :type gt_labels: List[int]
95+ :param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
96+ :type pred_boxes: np.ndarray
97+ :param pred_labels: Predicted class labels
98+ :type pred_labels: List[int]
99+ :param pred_scores: Prediction confidence scores
100+ :type pred_scores: List[float]
101+ :param iou_threshold: IoU threshold for matching, overrides self.iou_threshold if provided, defaults to None
102+ :type iou_threshold: Optional[float], optional
103+ :return: Dictionary mapping class labels to list of (score, tp_or_fp) tuples
104+ :rtype: Dict[int, List[Tuple[float, int]]]
85105 """
86106 if iou_threshold is None :
87107 iou_threshold = self .iou_threshold
@@ -119,11 +139,10 @@ def _match_predictions(
119139 return results
120140
121141 def compute_metrics (self ) -> Dict [int , Dict [str , float ]]:
122- """
123- Compute per-class precision, recall, AP, and mAP.
142+ """Compute per-class precision, recall, AP, and mAP.
124143
125- Returns:
126- Dict[class_id , Dict[str, float]], plus an entry for mAP under key -1
144+ :return: Dictionary mapping class IDs to metric dictionaries, plus mAP under key -1
145+ :rtype: Dict[int , Dict[str, float]]
127146 """
128147 metrics = {}
129148 ap_values = []
@@ -164,11 +183,10 @@ def compute_metrics(self) -> Dict[int, Dict[str, float]]:
164183 return metrics
165184
166185 def compute_coco_map (self ) -> float :
167- """
168- Compute COCO-style mAP (mean AP over IoU thresholds 0.5:0.05:0.95).
186+ """Compute COCO-style mAP (mean AP over IoU thresholds 0.5:0.05:0.95).
169187
170- Returns:
171- float: mAP@[0.5:0.95]
188+ :return: mAP@[0.5:0.95]
189+ :rtype: float
172190 """
173191 iou_thresholds = np .arange (0.5 , 1.0 , 0.05 )
174192 aps = []
@@ -240,11 +258,10 @@ def compute_coco_map(self) -> float:
240258 return np .mean (aps ) if aps else 0.0
241259
242260 def get_overall_precision_recall_curve (self ) -> Dict [str , List [float ]]:
243- """
244- Get overall precision-recall curve data (aggregated across all classes).
261+ """Get overall precision-recall curve data (aggregated across all classes).
245262
246- Returns:
247- Dict[str, List[float]] with keys 'precision' and 'recall'
263+ :return: Dictionary with 'precision' and 'recall' keys containing curve data
264+ :rtype: Dict[str, List[float]]
248265 """
249266 all_detections = []
250267
@@ -274,11 +291,10 @@ def get_overall_precision_recall_curve(self) -> Dict[str, List[float]]:
274291 }
275292
276293 def compute_auc_pr (self ) -> float :
277- """
278- Compute the Area Under the Precision-Recall Curve (AUC-PR).
294+ """Compute the Area Under the Precision-Recall Curve (AUC-PR).
279295
280- Returns:
281- float: Area under the precision-recall curve
296+ :return: Area under the precision-recall curve
297+ :rtype: float
282298 """
283299 curve_data = self .get_overall_precision_recall_curve ()
284300 precision = np .array (curve_data ["precision" ])
@@ -299,10 +315,12 @@ def compute_auc_pr(self) -> float:
299315 return float (auc )
300316
301317 def get_metrics_dataframe (self , ontology : dict ) -> pd .DataFrame :
302- """
303- Get results as a pandas DataFrame.
318+ """Get results as a pandas DataFrame.
304319
305320 :param ontology: Mapping from class name → { "idx": int }
321+ :type ontology: dict
322+ :return: DataFrame with metrics as rows and classes as columns (with mean)
323+ :rtype: pd.DataFrame
306324 """
307325 all_metrics = self .compute_metrics ()
308326 # Build a dict: metric -> {class_name: value}
@@ -340,8 +358,14 @@ def get_metrics_dataframe(self, ontology: dict) -> pd.DataFrame:
340358
341359
342360def compute_iou_matrix (pred_boxes : np .ndarray , gt_boxes : np .ndarray ) -> np .ndarray :
343- """
344- Compute IoU matrix between pred and gt boxes.
361+ """Compute IoU matrix between pred and gt boxes.
362+
363+ :param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
364+ :type pred_boxes: np.ndarray
365+ :param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
366+ :type gt_boxes: np.ndarray
367+ :return: IoU matrix with shape (num_pred, num_gt)
368+ :rtype: np.ndarray
345369 """
346370 iou_matrix = np .zeros ((len (pred_boxes ), len (gt_boxes )))
347371 for i , pb in enumerate (pred_boxes ):
@@ -351,6 +375,15 @@ def compute_iou_matrix(pred_boxes: np.ndarray, gt_boxes: np.ndarray) -> np.ndarr
351375
352376
353377def compute_iou (boxA , boxB ):
378+ """Compute Intersection over Union (IoU) between two bounding boxes.
379+
380+ :param boxA: First bounding box [x1, y1, x2, y2]
381+ :type boxA: array-like
382+ :param boxB: Second bounding box [x1, y1, x2, y2]
383+ :type boxB: array-like
384+ :return: IoU value between 0 and 1
385+ :rtype: float
386+ """
354387 xA = max (boxA [0 ], boxB [0 ])
355388 yA = max (boxA [1 ], boxB [1 ])
356389 xB = min (boxA [2 ], boxB [2 ])
@@ -364,6 +397,17 @@ def compute_iou(boxA, boxB):
364397
365398
366399def compute_ap (tps , fps , fn ):
400+ """Compute Average Precision (AP) using VOC-style 11-point interpolation.
401+
402+ :param tps: List of true positive flags
403+ :type tps: List[bool] or np.ndarray
404+ :param fps: List of false positive flags
405+ :type fps: List[bool] or np.ndarray
406+ :param fn: Number of false negatives
407+ :type fn: int
408+ :return: Tuple of (AP, precision array, recall array)
409+ :rtype: Tuple[float, np.ndarray, np.ndarray]
410+ """
367411 tps = np .array (tps , dtype = np .float32 )
368412 fps = np .array (fps , dtype = np .float32 )
369413
0 commit comments