Skip to content

Commit ee880b8

Browse files
committed
Committing Changes for Yolo versioning
Signed-off-by: [email protected] <[email protected]>
1 parent 30e57d1 commit ee880b8

File tree

4 files changed

+427
-130
lines changed

4 files changed

+427
-130
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 4 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -41,62 +41,6 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44-
class PyTorchYoloLossWrapper(torch.nn.Module):
45-
"""Wrapper for YOLO v8+ models to handle loss dict format."""
46-
47-
def __init__(self, model, name):
48-
super().__init__()
49-
self.model = model
50-
try:
51-
from ultralytics.models.yolo.detect import DetectionPredictor
52-
from ultralytics.utils.loss import v8DetectionLoss, E2EDetectLoss
53-
54-
self.detection_predictor = DetectionPredictor()
55-
self.model.args = self.detection_predictor.args
56-
if 'v10' in name:
57-
self.model.criterion = E2EDetectLoss(model)
58-
else:
59-
self.model.criterion = v8DetectionLoss(model)
60-
except ImportError as e:
61-
raise ImportError("The 'ultralytics' package is required for YOLO v8+ models but not installed.") from e
62-
63-
def forward(self, x, targets=None):
64-
if self.training:
65-
# batch_idx is used to identify which predictions/boxes relate to which image
66-
boxes = []
67-
labels = []
68-
indices = []
69-
for i, item in enumerate(targets):
70-
boxes.append(item['boxes'])
71-
labels.append(item['labels'])
72-
indices = indices + ([i]*len(item['labels']))
73-
items = {'boxes': torch.cat(boxes) / x.shape[2],
74-
'labels': torch.cat(labels).type(torch.float32),
75-
'batch_idx': torch.tensor(indices)}
76-
items['bboxes'] = items.pop('boxes')
77-
items['cls'] = items.pop('labels')
78-
items['img'] = x
79-
80-
loss, loss_components = self.model.loss(items)
81-
loss_components_dict = {"loss_total": loss.sum()}
82-
loss_components_dict['loss_box'] = loss_components[0]
83-
loss_components_dict['loss_cls'] = loss_components[1]
84-
loss_components_dict['loss_dfl'] = loss_components[2]
85-
return loss_components_dict
86-
else:
87-
preds = self.model(x)
88-
self.detection_predictor.model = self.model
89-
self.detection_predictor.batch = [x]
90-
preds = self.detection_predictor.postprocess(preds, x, x)
91-
# translate the preds to ART supported format
92-
items = []
93-
for pred in preds:
94-
items.append({'boxes': pred.boxes.xyxy,
95-
'scores': pred.boxes.conf,
96-
'labels': pred.boxes.cls.type(torch.int)})
97-
return items
98-
99-
10044
class PyTorchYolo(PyTorchObjectDetector):
10145
"""
10246
This module implements the model- and task specific estimator for YOLO v3, v5, v8+ object detector models in PyTorch.
@@ -156,8 +100,10 @@ def __init__(
156100
"""
157101
# Wrap the model with YoloWrapper if it's a YOLO v8+ model
158102
if is_yolov8:
159-
model = YoloWrapper(model, model_name)
160-
103+
from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
104+
105+
model = PyTorchYoloLossWrapper(model, model_name)
106+
161107
super().__init__(
162108
model=model,
163109
input_shape=input_shape,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
PyTorch-specific YOLO loss wrapper for ART for yolo versions 8 and above.
20+
"""
21+
22+
import torch
23+
24+
25+
class PyTorchYoloLossWrapper(torch.nn.Module):
26+
"""Wrapper for YOLO v8+ models to handle loss dict format."""
27+
28+
def __init__(self, model, name):
29+
super().__init__()
30+
self.model = model
31+
try:
32+
from ultralytics.models.yolo.detect import DetectionPredictor
33+
from ultralytics.utils.loss import v8DetectionLoss, E2EDetectLoss
34+
35+
self.detection_predictor = DetectionPredictor()
36+
self.model.args = self.detection_predictor.args
37+
if "v10" in name:
38+
self.model.criterion = E2EDetectLoss(model)
39+
else:
40+
self.model.criterion = v8DetectionLoss(model)
41+
except ImportError as e:
42+
raise ImportError("The 'ultralytics' package is required for YOLO v8+ models but not installed.") from e
43+
44+
def forward(self, x, targets=None):
45+
if self.training:
46+
boxes = []
47+
labels = []
48+
indices = []
49+
for i, item in enumerate(targets):
50+
boxes.append(item["boxes"])
51+
labels.append(item["labels"])
52+
indices = indices + ([i] * len(item["labels"]))
53+
items = {
54+
"boxes": torch.cat(boxes) / x.shape[2],
55+
"labels": torch.cat(labels).type(torch.float32),
56+
"batch_idx": torch.tensor(indices),
57+
}
58+
items["bboxes"] = items.pop("boxes")
59+
items["cls"] = items.pop("labels")
60+
items["img"] = x
61+
loss, loss_components = self.model.loss(items)
62+
loss_components_dict = {"loss_total": loss.sum()}
63+
loss_components_dict["loss_box"] = loss_components[0].sum()
64+
loss_components_dict["loss_cls"] = loss_components[1].sum()
65+
loss_components_dict["loss_dfl"] = loss_components[2].sum()
66+
return loss_components_dict
67+
else:
68+
preds = self.model(x)
69+
self.detection_predictor.model = self.model
70+
self.detection_predictor.batch = [x]
71+
preds = self.detection_predictor.postprocess(preds, x, x)
72+
items = []
73+
for pred in preds:
74+
items.append(
75+
{"boxes": pred.boxes.xyxy, "scores": pred.boxes.conf, "labels": pred.boxes.cls.type(torch.int)}
76+
)
77+
return items

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 20 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -324,89 +324,37 @@ def test_patch(art_warning, get_pytorch_yolo):
324324
art_warning(e)
325325

326326

327-
@pytest.mark.only_with_platform("pytorch")
328-
def test_translate_predictions_yolov8_format():
327+
def test_import_pytorch_yolo_loss_wrapper():
329328
import torch
330-
import numpy as np
331-
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo
329+
from art.estimators.object_detection.pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
332330

333-
# Create a dummy PyTorchYolo instance (model is not used for this test)
334-
class DummyModel(torch.nn.Module):
335-
def forward(self, x):
336-
return x
337-
dummy_model = DummyModel()
338-
yolo = PyTorchYolo(
339-
model=dummy_model,
340-
input_shape=(3, 416, 416),
341-
optimizer=None,
342-
clip_values=(0, 1),
343-
channels_first=True,
344-
attack_losses=("loss_total",),
345-
)
346-
347-
# Mock YOLO v8+ style predictions: list of dicts with torch tensors
348-
pred_boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
349-
pred_labels = torch.tensor([5], dtype=torch.int64)
350-
pred_scores = torch.tensor([0.9], dtype=torch.float32)
351-
predictions = [{
352-
"boxes": pred_boxes,
353-
"labels": pred_labels,
354-
"scores": pred_scores,
355-
}]
356-
357-
# Call the translation method
358-
translated = yolo._translate_predictions(predictions)
359-
360-
# Check output type and values
361-
assert isinstance(translated, list)
362-
assert isinstance(translated[0], dict)
363-
assert isinstance(translated[0]["boxes"], np.ndarray)
364-
assert isinstance(translated[0]["labels"], np.ndarray)
365-
assert isinstance(translated[0]["scores"], np.ndarray)
366-
np.testing.assert_array_equal(translated[0]["boxes"], pred_boxes.numpy())
367-
np.testing.assert_array_equal(translated[0]["labels"], pred_labels.numpy())
368-
np.testing.assert_array_equal(translated[0]["scores"], pred_scores.numpy())
369-
370-
371-
@pytest.mark.only_with_platform("pytorch")
372-
def test_pytorch_yolo_loss_wrapper_additional_losses():
373-
import torch
374-
from art.estimators.object_detection.pytorch_yolo import PyTorchYoloLossWrapper
375-
376-
# Dummy model with a .loss() method
377331
class DummyModel(torch.nn.Module):
378332
def __init__(self):
379333
super().__init__()
334+
380335
def loss(self, items):
381-
# Return (loss, [loss_box, loss_cls, loss_dfl])
382-
return (
383-
torch.tensor([1.0, 2.0, 3.0]),
384-
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]
385-
)
336+
return (torch.tensor([1.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
386337

387338
dummy_model = DummyModel()
388339
# Patch ultralytics import in the wrapper
389340
import sys
390341
import types
342+
391343
ultralytics_mock = types.SimpleNamespace(
392-
models=types.SimpleNamespace(yolo=types.SimpleNamespace(detect=types.SimpleNamespace(DetectionPredictor=lambda: types.SimpleNamespace(args=None)))),
393-
utils=types.SimpleNamespace(loss=types.SimpleNamespace(v8DetectionLoss=lambda m: None, E2EDetectLoss=lambda m: None))
344+
models=types.SimpleNamespace(
345+
yolo=types.SimpleNamespace(
346+
detect=types.SimpleNamespace(DetectionPredictor=lambda: types.SimpleNamespace(args=None))
347+
)
348+
),
349+
utils=types.SimpleNamespace(
350+
loss=types.SimpleNamespace(v8DetectionLoss=lambda m: None, E2EDetectLoss=lambda m: None)
351+
),
394352
)
395-
sys.modules['ultralytics'] = ultralytics_mock
396-
sys.modules['ultralytics.models'] = ultralytics_mock.models
397-
sys.modules['ultralytics.models.yolo'] = ultralytics_mock.models.yolo
398-
sys.modules['ultralytics.models.yolo.detect'] = ultralytics_mock.models.yolo.detect
399-
sys.modules['ultralytics.utils'] = ultralytics_mock.utils
400-
sys.modules['ultralytics.utils.loss'] = ultralytics_mock.utils.loss
401-
353+
sys.modules["ultralytics"] = ultralytics_mock
354+
sys.modules["ultralytics.models"] = ultralytics_mock.models
355+
sys.modules["ultralytics.models.yolo"] = ultralytics_mock.models.yolo
356+
sys.modules["ultralytics.models.yolo.detect"] = ultralytics_mock.models.yolo.detect
357+
sys.modules["ultralytics.utils"] = ultralytics_mock.utils
358+
sys.modules["ultralytics.utils.loss"] = ultralytics_mock.utils.loss
402359
wrapper = PyTorchYoloLossWrapper(dummy_model, name="yolov8n")
403-
wrapper.train()
404-
# Dummy input and targets
405-
x = torch.zeros((1, 3, 416, 416))
406-
targets = [{"boxes": torch.zeros((1, 4)), "labels": torch.zeros((1,))}]
407-
losses = wrapper(x, targets)
408-
assert set(losses.keys()) == {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
409-
assert losses["loss_total"].item() == 6.0 # sum([1.0, 2.0, 3.0])
410-
assert losses["loss_box"].item() == 1.0
411-
assert losses["loss_cls"].item() == 2.0
412-
assert losses["loss_dfl"].item() == 3.0
360+
assert isinstance(wrapper, PyTorchYoloLossWrapper)

0 commit comments

Comments
 (0)