Commit 6c41ca8

Merge pull request #2675 from arjun-sachar/dev_1.20.0
Implementing Yolo v8+ Dependencies
2 parents e3bde9c + 0911f23 commit 6c41ca8

File tree

5 files changed: +591 -7 lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 26 additions & 7 deletions
@@ -16,7 +16,7 @@
 # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 """
-This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
+This module implements the task specific estimator for PyTorch YOLO v3, v5, v8+ object detectors.
 
 | Paper link: https://arxiv.org/abs/1804.02767
 """
@@ -42,7 +42,7 @@
 
 class PyTorchYolo(PyTorchObjectDetector):
     """
-    This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
+    This module implements the model- and task specific estimator for YOLO object detector models in PyTorch.
 
     | Paper link: https://arxiv.org/abs/1804.02767
     """
@@ -65,11 +65,12 @@ def __init__(
         ),
         device_type: str = "gpu",
         is_yolov8: bool = False,
+        model_name: str | None = None,
     ):
         """
         Initialization.
 
-        :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
+        :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
                       The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
                       The fields of the dict are as follows:
@@ -93,8 +94,15 @@ def __init__(
                               'loss_objectness', and 'loss_rpn_box_reg'.
         :param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
                             if available otherwise run on CPU.
-        :param is_yolov8: The flag to be used for marking the YOLOv8 model.
+        :param is_yolov8: The flag to be used for marking the YOLOv8+ model.
+        :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n') for determining loss function.
         """
+        # Wrap the model with YoloWrapper if it's a YOLO v8+ model
+        if is_yolov8:
+            from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
+
+            model = PyTorchYoloLossWrapper(model, model_name)
+
         super().__init__(
             model=model,
             input_shape=input_shape,
@@ -154,20 +162,31 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
         Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
         convert tensors to numpy arrays.
 
-        :param predictions: Object detection labels in format xcycwh (YOLO).
+        :param predictions: Object detection labels in format xcycwh (YOLO) or list of dicts (YOLO v8+).
         :return: Object detection labels in format x1y1x2y2 (torchvision).
         """
         import torch
 
+        predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
+
+        # Handle YOLO v8+ predictions (list of dicts)
+        if isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
+            for pred in predictions:
+                prediction = {}
+                prediction["boxes"] = pred["boxes"].detach().cpu().numpy()
+                prediction["labels"] = pred["labels"].detach().cpu().numpy()
+                prediction["scores"] = pred["scores"].detach().cpu().numpy()
+                predictions_x1y1x2y2.append(prediction)
+            return predictions_x1y1x2y2
+
+        # Handle traditional YOLO predictions (tensor format)
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
         else:
             height = self.input_shape[0]
             width = self.input_shape[1]
 
-        predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
-
         for pred in predictions:
             boxes = torch.vstack(
                 [
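
For orientation, a minimal usage sketch (not part of this commit) of how the new flag and parameter could be exercised. The ultralytics model handle and input shape below are assumptions based on the docstring above, not values taken from the diff.

# Hedged sketch: wrapping an ultralytics YOLOv8 model in the ART estimator.
# Assumes the 'ultralytics' package is installed and that the underlying
# torch module is exposed via the .model attribute, as in the get_started example.
from ultralytics import YOLO
from art.estimators.object_detection import PyTorchYolo

yolo = YOLO("yolov8n.pt")  # pretrained weights; downloaded if not cached

detector = PyTorchYolo(
    model=yolo.model,
    input_shape=(3, 640, 640),
    is_yolov8=True,        # triggers the PyTorchYoloLossWrapper wrapping above
    model_name="yolov8n",  # selects v8DetectionLoss (E2EDetectLoss for 'v10' names)
)
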
art/estimators/object_detection/pytorch_yolo_loss_wrapper.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+PyTorch-specific YOLO loss wrapper for ART for yolo versions 8 and above.
+"""
+
+import torch
+
+
+class PyTorchYoloLossWrapper(torch.nn.Module):
+    """Wrapper for YOLO v8+ models to handle loss dict format."""
+
+    def __init__(self, model, name):
+        super().__init__()
+        self.model = model
+        try:
+            from ultralytics.models.yolo.detect import DetectionPredictor
+            from ultralytics.utils.loss import v8DetectionLoss, E2EDetectLoss
+
+            self.detection_predictor = DetectionPredictor()
+            self.model.args = self.detection_predictor.args
+            if "v10" in name:
+                self.model.criterion = E2EDetectLoss(model)
+            else:
+                self.model.criterion = v8DetectionLoss(model)
+        except ImportError as e:
+            raise ImportError("The 'ultralytics' package is required for YOLO v8+ models but not installed.") from e
+
+    def forward(self, x, targets=None):
+        if self.training:
+            boxes = []
+            labels = []
+            indices = []
+            for i, item in enumerate(targets):
+                boxes.append(item["boxes"])
+                labels.append(item["labels"])
+                indices = indices + ([i] * len(item["labels"]))
+            items = {
+                "boxes": torch.cat(boxes) / x.shape[2],
+                "labels": torch.cat(labels).type(torch.float32),
+                "batch_idx": torch.tensor(indices),
+            }
+            items["bboxes"] = items.pop("boxes")
+            items["cls"] = items.pop("labels")
+            items["img"] = x
+            loss, loss_components = self.model.loss(items)
+            loss_components_dict = {"loss_total": loss.sum()}
+            loss_components_dict["loss_box"] = loss_components[0].sum()
+            loss_components_dict["loss_cls"] = loss_components[1].sum()
+            loss_components_dict["loss_dfl"] = loss_components[2].sum()
+            return loss_components_dict
+        else:
+            preds = self.model(x)
+            self.detection_predictor.model = self.model
+            self.detection_predictor.batch = [x]
+            preds = self.detection_predictor.postprocess(preds, x, x)
+            items = []
+            for pred in preds:
+                items.append(
+                    {"boxes": pred.boxes.xyxy, "scores": pred.boxes.conf, "labels": pred.boxes.cls.type(torch.int)}
+                )
+            return items
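
A hedged sketch of the two forward paths above, exercising the wrapper directly rather than through the estimator. The ultralytics model handle, input size, and target values are illustrative assumptions, not taken from the commit.

# Sketch: eval mode returns detections, train mode returns the loss dict.
import torch
from ultralytics import YOLO
from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper

wrapper = PyTorchYoloLossWrapper(YOLO("yolov8n.pt").model, name="yolov8n")
x = torch.rand(2, 3, 640, 640)

# Eval mode: list of dicts with 'boxes' (x1y1x2y2), 'scores', 'labels'.
wrapper.eval()
detections = wrapper(x)

# Train mode: dict with 'loss_total', 'loss_box', 'loss_cls', 'loss_dfl'.
wrapper.train()
targets = [
    {"boxes": torch.tensor([[10.0, 20.0, 200.0, 300.0]]), "labels": torch.tensor([1])},
    {"boxes": torch.tensor([[50.0, 60.0, 400.0, 500.0]]), "labels": torch.tensor([2])},
]
losses = wrapper(x, targets)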

requirements_test.txt

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ torchvision==0.22.1
 # PyTorch image transformers
 timm==1.0.15
 
+# YOLO dependencies
+ultralytics==8.3.159
+
 catboost==1.2.8
 GPy==1.13.2
 lightgbm==4.6.0
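
A quick, hedged sanity check that the pinned test dependency added above is what is installed in the environment:

# Sketch: confirm the ultralytics pin from requirements_test.txt.
import importlib.metadata

print(importlib.metadata.version("ultralytics"))  # expected: 8.3.159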

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 36 additions & 0 deletions
@@ -322,3 +322,39 @@ def test_patch(art_warning, get_pytorch_yolo):
 
     except ARTTestException as e:
         art_warning(e)
+
+
+def test_import_pytorch_yolo_loss_wrapper():
+    import torch
+    from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
+
+    class DummyModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def loss(self, items):
+            return (torch.tensor([1.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
+
+    test_model = DummyModel()
+    # Patch ultralytics import in the wrapper
+    import sys
+    import types
+
+    ultralytics_mock = types.SimpleNamespace(
+        models=types.SimpleNamespace(
+            yolo=types.SimpleNamespace(
+                detect=types.SimpleNamespace(DetectionPredictor=lambda: types.SimpleNamespace(args=None))
+            )
+        ),
+        utils=types.SimpleNamespace(
+            loss=types.SimpleNamespace(v8DetectionLoss=lambda m: None, E2EDetectLoss=lambda m: None)
+        ),
+    )
+    sys.modules["ultralytics"] = ultralytics_mock
+    sys.modules["ultralytics.models"] = ultralytics_mock.models
+    sys.modules["ultralytics.models.yolo"] = ultralytics_mock.models.yolo
+    sys.modules["ultralytics.models.yolo.detect"] = ultralytics_mock.models.yolo.detect
+    sys.modules["ultralytics.utils"] = ultralytics_mock.utils
+    sys.modules["ultralytics.utils.loss"] = ultralytics_mock.utils.loss
+    wrapper = PyTorchYoloLossWrapper(test_model, name="yolov8n")
+    assert isinstance(wrapper, PyTorchYoloLossWrapper)
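
A hedged variant of the same idea, reusing DummyModel and the wrapper import from the test above: scoping the fake modules with unittest.mock.patch.dict so sys.modules is restored automatically after the test. This is an illustrative sketch, not part of the commit.

# Sketch: the same fake-module setup, but cleaned up when the block exits.
import sys
import types
from unittest import mock

detect_mod = types.SimpleNamespace(DetectionPredictor=lambda: types.SimpleNamespace(args=None))
loss_mod = types.SimpleNamespace(v8DetectionLoss=lambda m: None, E2EDetectLoss=lambda m: None)
fake_modules = {
    "ultralytics": types.SimpleNamespace(),
    "ultralytics.models": types.SimpleNamespace(),
    "ultralytics.models.yolo": types.SimpleNamespace(),
    "ultralytics.models.yolo.detect": detect_mod,
    "ultralytics.utils": types.SimpleNamespace(),
    "ultralytics.utils.loss": loss_mod,
}

with mock.patch.dict(sys.modules, fake_modules):
    wrapper = PyTorchYoloLossWrapper(DummyModel(), name="yolov8n")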

0 commit comments
