Skip to content

Commit bf2ff89

Browse files
committed
Final Edits for Yolo Dependency Allignment
Signed-off-by: [email protected] <[email protected]>
1 parent c437e6d commit bf2ff89

File tree

4 files changed

+134
-9
lines changed

4 files changed

+134
-9
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323
from __future__ import annotations
2424

2525
import logging
26-
from typing import TYPE_CHECKING, Optional, Union
26+
from typing import TYPE_CHECKING
2727

2828
import numpy as np
29-
import torch
3029

3130
from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector
3231

@@ -66,7 +65,7 @@ def __init__(
6665
),
6766
device_type: str = "gpu",
6867
is_yolov8: bool = False,
69-
model_name: str = "",
68+
model_name: str | None = None,
7069
):
7170
"""
7271
Initialization.
@@ -168,9 +167,10 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
168167
"""
169168
import torch
170169

170+
predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
171+
171172
# Handle YOLO v8+ predictions (list of dicts)
172173
if isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
173-
predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
174174
for pred in predictions:
175175
prediction = {}
176176
prediction["boxes"] = pred["boxes"].detach().cpu().numpy()
@@ -187,8 +187,6 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
187187
height = self.input_shape[0]
188188
width = self.input_shape[1]
189189

190-
predictions_x1y1x2y2: list[dict[str, np.ndarray]] = []
191-
192190
for pred in predictions:
193191
boxes = torch.vstack(
194192
[

art/estimators/object_detection/pytorch_yolo_loss_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
66
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the

requirements_test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ torchvision==0.22.1
3535
# PyTorch image transformers
3636
timm==1.0.15
3737

38+
# YOLO dependencies
39+
ultralytics==8.3.159
40+
3841
catboost==1.2.8
3942
GPy==1.13.2
4043
lightgbm==4.6.0

tests/estimators/object_detection/test_pytorch_yolo_loss_wrapper.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,135 @@
11
# MIT License
22
#
3-
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
44
#
5-
# Test for PyTorchYoloLossWrapper
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
618
import pytest
719
import torch
20+
import os
821
from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
22+
from ultralytics import YOLO
23+
24+
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
25+
26+
27+
@pytest.mark.only_with_platform("pytorch")
28+
def test_yolov8_loss_wrapper():
29+
"""Test the loss wrapper with YOLOv8 model."""
30+
# Load YOLOv8 model
31+
model_path = "/tmp/yolo_v8.3.0/yolov8n.pt"
32+
model = YOLO(model_path).model
33+
34+
# Create wrapper
35+
wrapper = PyTorchYoloLossWrapper(model, name="yolov8n")
36+
wrapper.train()
37+
38+
# Create sample input
39+
batch_size = 2
40+
x = torch.randn((batch_size, 3, 640, 640)) # YOLOv8 expects (B, 3, 640, 640)
41+
42+
# Create targets
43+
targets = []
44+
for _ in range(batch_size):
45+
boxes = torch.tensor([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.8]]) # [x1, y1, x2, y2]
46+
labels = torch.zeros(2, dtype=torch.long) # Use class 0 for testing
47+
targets.append({"boxes": boxes, "labels": labels})
48+
49+
# Test training mode
50+
losses = wrapper(x, targets)
51+
52+
# Validate loss structure
53+
expected_loss_keys = {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
54+
assert set(losses.keys()) == expected_loss_keys
55+
assert all(isinstance(v, torch.Tensor) for v in losses.values())
56+
assert all(not torch.isnan(v).any() for v in losses.values()), "Loss values contain NaN"
57+
assert all(not torch.isinf(v).any() for v in losses.values()), "Loss values contain Inf"
58+
59+
# Test inference mode
60+
wrapper.eval()
61+
with torch.no_grad():
62+
predictions = wrapper(x)
63+
64+
# Validate predictions
65+
assert isinstance(predictions, list)
66+
assert len(predictions) == batch_size
67+
for pred in predictions:
68+
assert set(pred.keys()) == {"boxes", "scores", "labels"}
69+
assert isinstance(pred["boxes"], torch.Tensor)
70+
assert isinstance(pred["scores"], torch.Tensor)
71+
assert isinstance(pred["labels"], torch.Tensor)
72+
assert pred["boxes"].ndim == 2 and pred["boxes"].shape[1] == 4
73+
assert pred["scores"].ndim == 1
74+
assert pred["labels"].ndim == 1
75+
assert pred["scores"].shape[0] == pred["labels"].shape[0] == pred["boxes"].shape[0]
76+
assert pred["boxes"].dtype == torch.float32
77+
assert pred["labels"].dtype in (torch.int32, torch.int64)
78+
79+
80+
@pytest.mark.only_with_platform("pytorch")
81+
def test_yolov10_loss_wrapper():
82+
"""Test the loss wrapper with YOLOv10 model."""
83+
# Load YOLOv10 model
84+
model_path = "/tmp/yolo_v8.3.0/yolov10n.pt"
85+
model = YOLO(model_path).model
86+
87+
# Create wrapper
88+
wrapper = PyTorchYoloLossWrapper(model, name="yolov10n")
89+
wrapper.train()
90+
91+
# Create sample input
92+
batch_size = 2
93+
x = torch.randn((batch_size, 3, 640, 640)) # Standard YOLO input size
94+
95+
# Create targets
96+
targets = []
97+
for _ in range(batch_size):
98+
boxes = torch.tensor([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.8]]) # [x1, y1, x2, y2]
99+
labels = torch.zeros(2, dtype=torch.long) # Use class 0 for testing
100+
targets.append({"boxes": boxes, "labels": labels})
101+
102+
# Test training mode
103+
losses = wrapper(x, targets)
104+
105+
# Validate loss structure
106+
expected_loss_keys = {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
107+
assert set(losses.keys()) == expected_loss_keys
108+
assert all(isinstance(v, torch.Tensor) for v in losses.values())
109+
assert all(not torch.isnan(v).any() for v in losses.values()), "Loss values contain NaN"
110+
assert all(not torch.isinf(v).any() for v in losses.values()), "Loss values contain Inf"
111+
assert all(v.item() >= 0 for v in losses.values()), "Loss values should be non-negative"
112+
assert losses["loss_total"].item() > 0, "Total loss should be positive"
113+
114+
# Test inference mode
115+
wrapper.eval()
116+
with torch.no_grad():
117+
predictions = wrapper(x)
118+
119+
# Validate predictions
120+
assert isinstance(predictions, list)
121+
assert len(predictions) == batch_size
122+
for pred in predictions:
123+
assert set(pred.keys()) == {"boxes", "scores", "labels"}
124+
assert isinstance(pred["boxes"], torch.Tensor)
125+
assert isinstance(pred["scores"], torch.Tensor)
126+
assert isinstance(pred["labels"], torch.Tensor)
127+
assert pred["boxes"].ndim == 2 and pred["boxes"].shape[1] == 4
128+
assert pred["scores"].ndim == 1
129+
assert pred["labels"].ndim == 1
130+
assert pred["scores"].shape[0] == pred["labels"].shape[0] == pred["boxes"].shape[0]
131+
assert pred["boxes"].dtype == torch.float32
132+
assert pred["labels"].dtype in (torch.int32, torch.int64)
9133

10134

11135
@pytest.mark.only_with_platform("pytorch")

0 commit comments

Comments
 (0)