Skip to content

Commit 8aa4631

Browse files
Made the requested changes
1 parent 93e941e commit 8aa4631

File tree

2 files changed

+84
-38
lines changed

2 files changed

+84
-38
lines changed

app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def browse_folder():
6060
folder = result.stdout.strip()
6161
if folder:
6262
return folder
63-
except Exception:
63+
except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
6464
continue
6565
return None
6666
except Exception:
@@ -128,6 +128,8 @@ def browse_folder():
128128
st.rerun()
129129
elif folder is not None:
130130
st.warning("Selected path is not a valid folder.")
131+
else:
132+
st.warning("Could not open folder browser. Please enter the path manually")
131133

132134
if dataset_path_input != st.session_state.get("dataset_path", ""):
133135
st.session_state["dataset_path"] = dataset_path_input
@@ -193,7 +195,7 @@ def browse_folder():
193195
with col2:
194196
st.selectbox(
195197
"Device",
196-
["cpu", "gpu"],
198+
["cpu", "cuda", "mps"],
197199
index=0 if st.session_state.get("device", "cpu") == "cpu" else 1,
198200
key="device",
199201
)

detectionmetrics/utils/detection_metrics.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66

77
class DetectionMetricsFactory:
8+
"""Factory class for computing detection metrics including precision, recall, AP, and mAP.
9+
10+
:param iou_threshold: IoU threshold for matching predictions to ground truth, defaults to 0.5
11+
:type iou_threshold: float, optional
12+
:param num_classes: Number of classes in the dataset, defaults to None
13+
:type num_classes: Optional[int], optional
14+
"""
15+
816
def __init__(self, iou_threshold: float = 0.5, num_classes: Optional[int] = None):
917
self.iou_threshold = iou_threshold
1018
self.num_classes = num_classes
@@ -15,14 +23,18 @@ def __init__(self, iou_threshold: float = 0.5, num_classes: Optional[int] = None
1523
) # List of (gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores)
1624

1725
def update(self, gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores):
18-
"""
19-
Add a batch of predictions and ground truths.
20-
21-
:param gt_boxes: List[ndarray], shape (num_gt, 4)
22-
:param gt_labels: List[int]
23-
:param pred_boxes: List[ndarray], shape (num_pred, 4)
24-
:param pred_labels: List[int]
25-
:param pred_scores: List[float]
26+
"""Add a batch of predictions and ground truths.
27+
28+
:param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
29+
:type gt_boxes: List[ndarray]
30+
:param gt_labels: Ground truth class labels
31+
:type gt_labels: List[int]
32+
:param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
33+
:type pred_boxes: List[ndarray]
34+
:param pred_labels: Predicted class labels
35+
:type pred_labels: List[int]
36+
:param pred_scores: Prediction confidence scores
37+
:type pred_scores: List[float]
2638
"""
2739

2840
# Convert torch tensors to numpy
@@ -74,14 +86,22 @@ def _match_predictions(
7486
pred_scores: List[float],
7587
iou_threshold: Optional[float] = None,
7688
) -> Dict[int, List[Tuple[float, int]]]:
77-
"""
78-
Match predictions to ground truth and return per-class TP/FP flags with scores.
79-
80-
Args:
81-
iou_threshold: If provided, overrides self.iou_threshold
82-
83-
Returns:
84-
Dict[label_id, List[(score, tp_or_fp)]]
89+
"""Match predictions to ground truth and return per-class TP/FP flags with scores.
90+
91+
:param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
92+
:type gt_boxes: np.ndarray
93+
:param gt_labels: Ground truth class labels
94+
:type gt_labels: List[int]
95+
:param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
96+
:type pred_boxes: np.ndarray
97+
:param pred_labels: Predicted class labels
98+
:type pred_labels: List[int]
99+
:param pred_scores: Prediction confidence scores
100+
:type pred_scores: List[float]
101+
:param iou_threshold: IoU threshold for matching, overrides self.iou_threshold if provided, defaults to None
102+
:type iou_threshold: Optional[float], optional
103+
:return: Dictionary mapping class labels to list of (score, tp_or_fp) tuples
104+
:rtype: Dict[int, List[Tuple[float, int]]]
85105
"""
86106
if iou_threshold is None:
87107
iou_threshold = self.iou_threshold
@@ -119,11 +139,10 @@ def _match_predictions(
119139
return results
120140

121141
def compute_metrics(self) -> Dict[int, Dict[str, float]]:
122-
"""
123-
Compute per-class precision, recall, AP, and mAP.
142+
"""Compute per-class precision, recall, AP, and mAP.
124143
125-
Returns:
126-
Dict[class_id, Dict[str, float]], plus an entry for mAP under key -1
144+
:return: Dictionary mapping class IDs to metric dictionaries, plus mAP under key -1
145+
:rtype: Dict[int, Dict[str, float]]
127146
"""
128147
metrics = {}
129148
ap_values = []
@@ -164,11 +183,10 @@ def compute_metrics(self) -> Dict[int, Dict[str, float]]:
164183
return metrics
165184

166185
def compute_coco_map(self) -> float:
167-
"""
168-
Compute COCO-style mAP (mean AP over IoU thresholds 0.5:0.05:0.95).
186+
"""Compute COCO-style mAP (mean AP over IoU thresholds 0.5:0.05:0.95).
169187
170-
Returns:
171-
float: mAP@[0.5:0.95]
188+
:return: mAP@[0.5:0.95]
189+
:rtype: float
172190
"""
173191
iou_thresholds = np.arange(0.5, 1.0, 0.05)
174192
aps = []
@@ -240,11 +258,10 @@ def compute_coco_map(self) -> float:
240258
return np.mean(aps) if aps else 0.0
241259

242260
def get_overall_precision_recall_curve(self) -> Dict[str, List[float]]:
243-
"""
244-
Get overall precision-recall curve data (aggregated across all classes).
261+
"""Get overall precision-recall curve data (aggregated across all classes).
245262
246-
Returns:
247-
Dict[str, List[float]] with keys 'precision' and 'recall'
263+
:return: Dictionary with 'precision' and 'recall' keys containing curve data
264+
:rtype: Dict[str, List[float]]
248265
"""
249266
all_detections = []
250267

@@ -274,11 +291,10 @@ def get_overall_precision_recall_curve(self) -> Dict[str, List[float]]:
274291
}
275292

276293
def compute_auc_pr(self) -> float:
277-
"""
278-
Compute the Area Under the Precision-Recall Curve (AUC-PR).
294+
"""Compute the Area Under the Precision-Recall Curve (AUC-PR).
279295
280-
Returns:
281-
float: Area under the precision-recall curve
296+
:return: Area under the precision-recall curve
297+
:rtype: float
282298
"""
283299
curve_data = self.get_overall_precision_recall_curve()
284300
precision = np.array(curve_data["precision"])
@@ -299,10 +315,12 @@ def compute_auc_pr(self) -> float:
299315
return float(auc)
300316

301317
def get_metrics_dataframe(self, ontology: dict) -> pd.DataFrame:
302-
"""
303-
Get results as a pandas DataFrame.
318+
"""Get results as a pandas DataFrame.
304319
305320
:param ontology: Mapping from class name → { "idx": int }
321+
:type ontology: dict
322+
:return: DataFrame with metrics as rows and classes as columns (with mean)
323+
:rtype: pd.DataFrame
306324
"""
307325
all_metrics = self.compute_metrics()
308326
# Build a dict: metric -> {class_name: value}
@@ -340,8 +358,14 @@ def get_metrics_dataframe(self, ontology: dict) -> pd.DataFrame:
340358

341359

342360
def compute_iou_matrix(pred_boxes: np.ndarray, gt_boxes: np.ndarray) -> np.ndarray:
343-
"""
344-
Compute IoU matrix between pred and gt boxes.
361+
"""Compute IoU matrix between pred and gt boxes.
362+
363+
:param pred_boxes: Predicted bounding boxes, shape (num_pred, 4)
364+
:type pred_boxes: np.ndarray
365+
:param gt_boxes: Ground truth bounding boxes, shape (num_gt, 4)
366+
:type gt_boxes: np.ndarray
367+
:return: IoU matrix with shape (num_pred, num_gt)
368+
:rtype: np.ndarray
345369
"""
346370
iou_matrix = np.zeros((len(pred_boxes), len(gt_boxes)))
347371
for i, pb in enumerate(pred_boxes):
@@ -351,6 +375,15 @@ def compute_iou_matrix(pred_boxes: np.ndarray, gt_boxes: np.ndarray) -> np.ndarr
351375

352376

353377
def compute_iou(boxA, boxB):
378+
"""Compute Intersection over Union (IoU) between two bounding boxes.
379+
380+
:param boxA: First bounding box [x1, y1, x2, y2]
381+
:type boxA: array-like
382+
:param boxB: Second bounding box [x1, y1, x2, y2]
383+
:type boxB: array-like
384+
:return: IoU value between 0 and 1
385+
:rtype: float
386+
"""
354387
xA = max(boxA[0], boxB[0])
355388
yA = max(boxA[1], boxB[1])
356389
xB = min(boxA[2], boxB[2])
@@ -364,6 +397,17 @@ def compute_iou(boxA, boxB):
364397

365398

366399
def compute_ap(tps, fps, fn):
400+
"""Compute Average Precision (AP) using VOC-style 11-point interpolation.
401+
402+
:param tps: List of true positive flags
403+
:type tps: List[bool] or np.ndarray
404+
:param fps: List of false positive flags
405+
:type fps: List[bool] or np.ndarray
406+
:param fn: Number of false negatives
407+
:type fn: int
408+
:return: Tuple of (AP, precision array, recall array)
409+
:rtype: Tuple[float, np.ndarray, np.ndarray]
410+
"""
367411
tps = np.array(tps, dtype=np.float32)
368412
fps = np.array(fps, dtype=np.float32)
369413

0 commit comments

Comments
 (0)