
Commit c0b51d6

Testing PyTorch YOLO Implementation
Signed-off-by: [email protected] <[email protected]>
1 parent e3bde9c commit c0b51d6

File tree: 1 file changed (+81 -6 lines)


art/estimators/object_detection/pytorch_yolo.py

Lines changed: 81 additions & 6 deletions
@@ -16,16 +16,17 @@
 # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 """
-This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
+This module implements the task specific estimator for PyTorch YOLO v3, v5, and v8+ object detectors.

 | Paper link: https://arxiv.org/abs/1804.02767
 """
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
+import torch

 from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector

@@ -40,9 +41,65 @@
 logger = logging.getLogger(__name__)


+class PyTorchYoloLossWrapper(torch.nn.Module):
+    """Wrapper for YOLO v8+ models to handle the loss dict format."""
+
+    def __init__(self, model, name):
+        super().__init__()
+        self.model = model
+        try:
+            from ultralytics.models.yolo.detect import DetectionPredictor
+            from ultralytics.utils.loss import E2EDetectLoss, v8DetectionLoss
+
+            self.detection_predictor = DetectionPredictor()
+            self.model.args = self.detection_predictor.args
+            if "v10" in name:
+                self.model.criterion = E2EDetectLoss(model)
+            else:
+                self.model.criterion = v8DetectionLoss(model)
+        except ImportError as e:
+            raise ImportError("The 'ultralytics' package is required for YOLO v8+ models but not installed.") from e
+
+    def forward(self, x, targets=None):
+        if self.training:
+            # batch_idx identifies which boxes/labels belong to which image in the batch
+            boxes = []
+            labels = []
+            indices = []
+            for i, item in enumerate(targets):
+                boxes.append(item["boxes"])
+                labels.append(item["labels"])
+                indices = indices + [i] * len(item["labels"])
+            # ultralytics expects normalized boxes under 'bboxes' and class labels under 'cls'
+            items = {
+                "bboxes": torch.concatenate(boxes) / x.shape[2],
+                "cls": torch.concatenate(labels).type(torch.float32),
+                "batch_idx": torch.tensor(indices),
+                "img": x,
+            }
+
+            loss, loss_components = self.model.loss(items)
+            loss_components_dict = {
+                "loss_total": loss.sum(),
+                "loss_box": loss_components[0],
+                "loss_cls": loss_components[1],
+                "loss_dfl": loss_components[2],
+            }
+            return loss_components_dict
+        else:
+            preds = self.model(x)
+            self.detection_predictor.model = self.model
+            self.detection_predictor.batch = [x]
+            preds = self.detection_predictor.postprocess(preds, x, x)
+            # translate the predictions to the ART-supported format
+            items = []
+            for pred in preds:
+                items.append(
+                    {"boxes": pred.boxes.xyxy, "scores": pred.boxes.conf, "labels": pred.boxes.cls.type(torch.int)}
+                )
+            return items
+
+
 class PyTorchYolo(PyTorchObjectDetector):
     """
-    This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
+    This module implements the model- and task specific estimator for YOLO v3, v5, and v8+ object detector models
+    in PyTorch.

     | Paper link: https://arxiv.org/abs/1804.02767
     """
@@ -65,11 +122,12 @@ def __init__(
         ),
         device_type: str = "gpu",
         is_yolov8: bool = False,
+        model_name: str = "",
     ):
         """
         Initialization.

-        :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
+        :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
                       The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
                       The fields of the dict are as follows:
@@ -93,8 +151,13 @@ def __init__(
                       'loss_objectness', and 'loss_rpn_box_reg'.
         :param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
                             if available otherwise run on CPU.
-        :param is_yolov8: The flag to be used for marking the YOLOv8 model.
+        :param is_yolov8: The flag to be used for marking a YOLO v8+ model.
+        :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n'), used to select the loss function.
         """
+        # Wrap the model with PyTorchYoloLossWrapper if it is a YOLO v8+ model
+        if is_yolov8:
+            model = PyTorchYoloLossWrapper(model, model_name)
+
         super().__init__(
             model=model,
             input_shape=input_shape,
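
For reference, a sketch of constructing the estimator with the new parameters; the checkpoint name, input shape, and clip values below are assumptions for illustration, not part of this commit.

    from ultralytics import YOLO
    from art.estimators.object_detection import PyTorchYolo

    yolo = YOLO("yolov8n.pt")
    estimator = PyTorchYolo(
        model=yolo.model,  # wrapped internally by PyTorchYoloLossWrapper because is_yolov8=True
        input_shape=(3, 640, 640),
        clip_values=(0.0, 1.0),
        is_yolov8=True,
        model_name="yolov8n",  # a name containing 'v10' selects E2EDetectLoss instead of v8DetectionLoss
    )
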
@@ -154,11 +217,23 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
         Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
         convert tensors to numpy arrays.

-        :param predictions: Object detection labels in format xcycwh (YOLO).
+        :param predictions: Object detection labels in format xcycwh (YOLO) or a list of dicts (YOLO v8+).
         :return: Object detection labels in format x1y1x2y2 (torchvision).
         """
         import torch

+        # Handle YOLO v8+ predictions (list of dicts, already in x1y1x2y2 format)
+        if isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
+            predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
+            for pred in predictions:
+                prediction = {}
+                prediction["boxes"] = pred["boxes"].detach().cpu().numpy()
+                prediction["labels"] = pred["labels"].detach().cpu().numpy()
+                prediction["scores"] = pred["scores"].detach().cpu().numpy()
+                predictions_x1y1x2y2.append(prediction)
+            return predictions_x1y1x2y2
+
+        # Handle traditional YOLO predictions (single tensor in xcycwh format)
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
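
For context on the tensor fallback path shown above, this is a self-contained sketch of the xcycwh-to-x1y1x2y2 conversion that path performs; the helper name is illustrative and not code from this commit.

    import torch

    def xcycwh_to_x1y1x2y2(boxes: torch.Tensor) -> torch.Tensor:
        """Convert center-based YOLO boxes (xc, yc, w, h) to corner-based torchvision boxes (x1, y1, x2, y2)."""
        x_c, y_c, w, h = boxes.unbind(-1)
        return torch.stack((x_c - w / 2.0, y_c - h / 2.0, x_c + w / 2.0, y_c + h / 2.0), dim=-1)
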
