
Commit 80dac84

Refactor detection handling to use DetectionResult and update related documentation
1 parent 146b702 commit 80dac84

File tree: 9 files changed (+243, −222 lines)


docs/source/python/models/detection_model.md

Lines changed: 8 additions & 13 deletions
@@ -12,15 +12,12 @@ A single input image of shape (H, W, 3) where H and W are the height and width o
 
 ### Outputs
 
-Detection model outputs a list of detection objects (i.e `list[Detection]`) wrapped in `DetectionResult`, each object containing the following attributes:
+Detection model outputs a `DetectionResult` object containing the following attributes:
 
-- `score` (float) - Confidence score of the object.
-- `id` (int) - Class label of the object.
-- `str_label` (str) - String label of the object.
-- `xmin` (int) - X-coordinate of the top-left corner of the bounding box.
-- `ymin` (int) - Y-coordinate of the top-left corner of the bounding box.
-- `xmax` (int) - X-coordinate of the bottom-right corner of the bounding box.
-- `ymax` (int) - Y-coordinate of the bottom-right corner of the bounding box.
+- `bboxes` (np.ndarray) - Bounding boxes of the detected objects, each in (x1, y1, x2, y2) format.
+- `scores` (np.ndarray) - Confidence scores of the detected objects.
+- `labels` (np.ndarray) - Class labels of the detected objects.
+- `label_names` (list[str]) - List of class names of the detected objects.
 
 ## Example
 
@@ -34,11 +31,9 @@ model = SSD.create_model("model.xml")
 # Forward pass
 predictions = model(image)
 
-# Iterate over the segmented objects
-for pred_obj in predictions.objects:
-    pred_score = pred_obj.score
-    label_id = pred_obj.id
-    bbox = [pred_obj.xmin, pred_obj.ymin, pred_obj.xmax, pred_obj.ymax]
+# Iterate over the detection result
+for box, score, label in zip(predictions.bboxes, predictions.scores, predictions.labels):
+    print(f"Box: {box}, Score: {score}, Label: {label}")
 
 ```
 
 ```{eval-rst}
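
For downstream users, the shape of the new result type matters more than the diff itself. Below is a minimal sketch of constructing and consuming a `DetectionResult` by hand; the constructor keyword names are an assumption inferred from the attributes documented above, since results are normally produced by the model's postprocessing rather than built manually:

```python
import numpy as np

from model_api.models import DetectionResult

# Hypothetical manual construction; argument names assumed from the
# documented attributes (bboxes, labels, scores, label_names).
result = DetectionResult(
    bboxes=np.array([[10, 20, 110, 220], [30, 40, 90, 160]], dtype=np.float32),
    labels=np.array([0, 1]),
    scores=np.array([0.91, 0.57], dtype=np.float32),
    label_names=["person", "car"],
)

# Array-valued attributes replace the old per-object Detection fields.
for box, score, name in zip(result.bboxes, result.scores, result.label_names):
    x1, y1, x2, y2 = box
    print(f"({x1}, {y1}, {x2}, {y2}) {name}: {score:.2f}")
```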

model_api/python/model_api/models/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -16,7 +16,6 @@
     ClassificationResult,
     Contour,
     DetectedKeypoints,
-    Detection,
     DetectionResult,
     ImageResultWithSoftPrediction,
     InstanceSegmentationResult,
@@ -90,7 +89,6 @@
     "SAMImageEncoder",
     "ClassificationResult",
     "Prompt",
-    "Detection",
     "DetectionResult",
    "DetectedKeypoints",
    "classification_models",

model_api/python/model_api/models/detection_model.py

Lines changed: 24 additions & 50 deletions
@@ -3,8 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+import numpy as np
+
 from .image_model import ImageModel
-from .result_types import Detection
+from .result_types import DetectionResult
 from .types import ListValue, NumericalValue, StringValue
 from .utils import load_labels
 
@@ -65,14 +67,14 @@ def parameters(cls):
 
         return parameters
 
-    def _resize_detections(self, detections: list[Detection], meta):
+    def _resize_detections(self, detection_result: DetectionResult, meta: dict) -> DetectionResult:
         """Resizes detection bounding boxes according to initial image shape.
 
         It implements image resizing depending on the set `resize_type`(see `ImageModel` for details).
         Next, it applies bounding boxes clipping.
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
+            detection_result (DetectionResult): detection result with coordinates in normalized form
             meta (dict): the input metadata obtained from `preprocess` method
 
         Returns:
@@ -92,63 +94,35 @@ def _resize_detections(self, detections: list[Detection], meta):
         pad_left = (self.w - round(input_img_widht / inverted_scale_x)) // 2
         pad_top = (self.h - round(input_img_height / inverted_scale_y)) // 2
 
-        def _clamp_and_round(val, min_value, max_value):
-            return round(max(min_value, min(max_value, val)))
-
-        for detection in detections:
-            detection.xmin = _clamp_and_round(
-                (detection.xmin * self.w - pad_left) * inverted_scale_x,
-                0,
-                input_img_widht,
-            )
-            detection.ymin = _clamp_and_round(
-                (detection.ymin * self.h - pad_top) * inverted_scale_y,
-                0,
-                input_img_height,
-            )
-            detection.xmax = _clamp_and_round(
-                (detection.xmax * self.w - pad_left) * inverted_scale_x,
-                0,
-                input_img_widht,
-            )
-            detection.ymax = _clamp_and_round(
-                (detection.ymax * self.h - pad_top) * inverted_scale_y,
-                0,
-                input_img_height,
-            )
+        boxes = detection_result.bboxes
+        boxes[:, 0::2] = (boxes[:, 0::2] * self.w - pad_left) * inverted_scale_x
+        boxes[:, 1::2] = (boxes[:, 1::2] * self.h - pad_top) * inverted_scale_y
+        boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, input_img_widht)
+        boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, input_img_height)
+        detection_result.bboxes = boxes
+        return detection_result
 
-        return detections
-
-    def _filter_detections(self, detections: list[Detection], box_area_threshold=0.0):
+    def _filter_detections(self, detection_result: DetectionResult, box_area_threshold=0.0):
         """Filters detections by confidence threshold and box size threshold
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
+            detection_result (DetectionResult): DetectionResult object with coordinates in normalized form
             box_area_threshold (float): minimal area of the bounding to be considered
 
         Returns:
             - list of detections with confidence above the threshold
         """
-        filtered_detections = []
-        for detection in detections:
-            if (
-                detection.score < self.confidence_threshold
-                or (detection.xmax - detection.xmin) * (detection.ymax - detection.ymin) < box_area_threshold
-            ):
-                continue
-            filtered_detections.append(detection)
-
-        return filtered_detections
-
-    def _add_label_names(self, detections: list[Detection]):
+        keep = (detection_result.get_obj_sizes() > box_area_threshold) & (
+            detection_result.scores > self.confidence_threshold
+        )
+        detection_result.bboxes = detection_result.bboxes[keep]
+        detection_result.labels = detection_result.labels[keep]
+        detection_result.scores = detection_result.scores[keep]
+
+    def _add_label_names(self, detection_result: DetectionResult) -> None:
         """Adds labels names to detections if they are available
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
-
-        Returns:
-            - list of detections with label strings
+            detection_result (DetectionResult): detection result to which label names are added
         """
-        for detection in detections:
-            detection.str_label = self.get_label_name(detection.id)
-        return detections
+        detection_result.label_names = [self.get_label_name(label_idx) for label_idx in detection_result.labels]
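
The refactor above replaces per-object Python loops with vectorized numpy operations over the whole (N, 4) box array. A standalone sketch of the same two steps on toy numbers (plain numpy, not the library API; `get_obj_sizes()` is assumed to return per-box areas, computed explicitly here):

```python
import numpy as np

# Toy inputs: normalized (x1, y1, x2, y2) boxes, one score and label per box.
boxes = np.array([[0.10, 0.20, 0.40, 0.80], [0.00, 0.00, 0.05, 0.02]], dtype=np.float32)
scores = np.array([0.90, 0.30], dtype=np.float32)
labels = np.array([1, 2])

model_w, model_h = 640, 480   # network input size (self.w, self.h)
pad_left, pad_top = 0, 40     # letterbox padding, if any
scale_x, scale_y = 2.0, 2.0   # inverted resize scales
orig_w, orig_h = 1280, 800    # original image size

# Resize step: [:, 0::2] selects both x-coordinates, [:, 1::2] both
# y-coordinates, so one expression rescales and one clips each axis
# for every box at once.
boxes[:, 0::2] = (boxes[:, 0::2] * model_w - pad_left) * scale_x
boxes[:, 1::2] = (boxes[:, 1::2] * model_h - pad_top) * scale_y
boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, orig_w)
boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, orig_h)

# Filter step: a single boolean mask replaces the old per-detection if/continue.
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
keep = (areas > 100.0) & (scores > 0.5)
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
print(boxes, scores, labels)
```

One behavioral difference worth noting: the old `_clamp_and_round` rounded coordinates to ints, while the vectorized version only clips, so box coordinates may now stay floating-point.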

model_api/python/model_api/models/keypoint_detection.py

Lines changed: 9 additions & 7 deletions
@@ -10,7 +10,7 @@
 import numpy as np
 
 from .image_model import ImageModel
-from .result_types import DetectedKeypoints, Detection
+from .result_types import DetectedKeypoints, DetectionResult
 from .types import ListValue
 
 
@@ -77,25 +77,27 @@ def __init__(self, base_model: KeypointDetectionModel) -> None:
     def predict(
         self,
         image: np.ndarray,
-        detections: list[Detection],
+        detection_result: DetectionResult,
     ) -> list[DetectedKeypoints]:
         """Predicts keypoints for the given image and detections.
 
         Args:
             image (np.ndarray): input full-size image
-            detections (list[Detection]): detections located within the given image
+            detection_result (DetectionResult): detections located within the given image
 
         Returns:
             list[DetectedKeypoints]: per detection keypoints in detection coordinates
         """
         crops = []
-        for det in detections:
-            crops.append(image[det.ymin : det.ymax, det.xmin : det.xmax])
+        for box in detection_result.bboxes:
+            x1, y1, x2, y2 = box
+            crops.append(image[y1:y2, x1:x2])
 
         crops_results = self.predict_crops(crops)
-        for i, det in enumerate(detections):
+        for i, box in enumerate(detection_result.bboxes):
+            x1, y1, x2, y2 = box
             crops_results[i] = DetectedKeypoints(
-                crops_results[i].keypoints + np.array([det.xmin, det.ymin]),
+                crops_results[i].keypoints + np.array([x1, y1]),
                 crops_results[i].scores,
             )
 
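
The keypoint path now crops directly from the `bboxes` array and shifts the crop-local keypoints back by the crop's top-left corner. A toy illustration of that offset logic (plain numpy with made-up coordinates; integer box values are assumed, since array slicing requires ints):

```python
import numpy as np

# A fake full-size image and one detection box (x1, y1, x2, y2).
image = np.zeros((480, 640, 3), dtype=np.uint8)
box = np.array([100, 50, 200, 150])
x1, y1, x2, y2 = box
crop = image[y1:y2, x1:x2]  # note: rows index y, columns index x

# Keypoints predicted inside the crop are in crop-local coordinates.
crop_keypoints = np.array([[10.0, 20.0], [55.0, 70.0]])

# Adding the crop's top-left corner maps them to full-image coordinates.
image_keypoints = crop_keypoints + np.array([x1, y1])
print(image_keypoints)  # [[110.  70.] [155. 120.]]
```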

model_api/python/model_api/models/result_types/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -5,7 +5,12 @@
 
 from .anomaly import AnomalyResult
 from .classification import ClassificationResult
-from .detection import Detection, DetectionResult
+from .detection import (
+    BoxesLabelsParser,
+    DetectionResult,
+    MultipleOutputParser,
+    SingleOutputParser,
+)
 from .keypoint import DetectedKeypoints
 from .segmentation import (
     Contour,
@@ -18,13 +23,15 @@
 
 __all__ = [
     "AnomalyResult",
+    "BoxesLabelsParser",
     "ClassificationResult",
     "Contour",
-    "Detection",
     "DetectionResult",
     "DetectedKeypoints",
+    "MultipleOutputParser",
     "SegmentedObject",
     "SegmentedObjectWithRects",
+    "SingleOutputParser",
     "ImageResultWithSoftPrediction",
     "InstanceSegmentationResult",
     "PredictedMask",
