Skip to content

Commit 30e57d1

Browse files
committed
Yolo V8+ Changes
Signed-off-by: [email protected] <[email protected]>
1 parent c0b51d6 commit 30e57d1

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def forward(self, x, targets=None):
7070
boxes.append(item['boxes'])
7171
labels.append(item['labels'])
7272
indices = indices + ([i]*len(item['labels']))
73-
items = {'boxes': torch.concatenate(boxes) / x.shape[2],
74-
'labels': torch.concatenate(labels).type(torch.float32),
73+
items = {'boxes': torch.cat(boxes) / x.shape[2],
74+
'labels': torch.cat(labels).type(torch.float32),
7575
'batch_idx': torch.tensor(indices)}
7676
items['bboxes'] = items.pop('boxes')
7777
items['cls'] = items.pop('labels')

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,91 @@ def test_patch(art_warning, get_pytorch_yolo):
322322

323323
except ARTTestException as e:
324324
art_warning(e)
325+
326+
327+
@pytest.mark.only_with_platform("pytorch")
328+
def test_translate_predictions_yolov8_format():
329+
import torch
330+
import numpy as np
331+
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo
332+
333+
# Create a dummy PyTorchYolo instance (model is not used for this test)
334+
class DummyModel(torch.nn.Module):
335+
def forward(self, x):
336+
return x
337+
dummy_model = DummyModel()
338+
yolo = PyTorchYolo(
339+
model=dummy_model,
340+
input_shape=(3, 416, 416),
341+
optimizer=None,
342+
clip_values=(0, 1),
343+
channels_first=True,
344+
attack_losses=("loss_total",),
345+
)
346+
347+
# Mock YOLO v8+ style predictions: list of dicts with torch tensors
348+
pred_boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
349+
pred_labels = torch.tensor([5], dtype=torch.int64)
350+
pred_scores = torch.tensor([0.9], dtype=torch.float32)
351+
predictions = [{
352+
"boxes": pred_boxes,
353+
"labels": pred_labels,
354+
"scores": pred_scores,
355+
}]
356+
357+
# Call the translation method
358+
translated = yolo._translate_predictions(predictions)
359+
360+
# Check output type and values
361+
assert isinstance(translated, list)
362+
assert isinstance(translated[0], dict)
363+
assert isinstance(translated[0]["boxes"], np.ndarray)
364+
assert isinstance(translated[0]["labels"], np.ndarray)
365+
assert isinstance(translated[0]["scores"], np.ndarray)
366+
np.testing.assert_array_equal(translated[0]["boxes"], pred_boxes.numpy())
367+
np.testing.assert_array_equal(translated[0]["labels"], pred_labels.numpy())
368+
np.testing.assert_array_equal(translated[0]["scores"], pred_scores.numpy())
369+
370+
371+
@pytest.mark.only_with_platform("pytorch")
372+
def test_pytorch_yolo_loss_wrapper_additional_losses():
373+
import torch
374+
from art.estimators.object_detection.pytorch_yolo import PyTorchYoloLossWrapper
375+
376+
# Dummy model with a .loss() method
377+
class DummyModel(torch.nn.Module):
378+
def __init__(self):
379+
super().__init__()
380+
def loss(self, items):
381+
# Return (loss, [loss_box, loss_cls, loss_dfl])
382+
return (
383+
torch.tensor([1.0, 2.0, 3.0]),
384+
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]
385+
)
386+
387+
dummy_model = DummyModel()
388+
# Patch ultralytics import in the wrapper
389+
import sys
390+
import types
391+
ultralytics_mock = types.SimpleNamespace(
392+
models=types.SimpleNamespace(yolo=types.SimpleNamespace(detect=types.SimpleNamespace(DetectionPredictor=lambda: types.SimpleNamespace(args=None)))),
393+
utils=types.SimpleNamespace(loss=types.SimpleNamespace(v8DetectionLoss=lambda m: None, E2EDetectLoss=lambda m: None))
394+
)
395+
sys.modules['ultralytics'] = ultralytics_mock
396+
sys.modules['ultralytics.models'] = ultralytics_mock.models
397+
sys.modules['ultralytics.models.yolo'] = ultralytics_mock.models.yolo
398+
sys.modules['ultralytics.models.yolo.detect'] = ultralytics_mock.models.yolo.detect
399+
sys.modules['ultralytics.utils'] = ultralytics_mock.utils
400+
sys.modules['ultralytics.utils.loss'] = ultralytics_mock.utils.loss
401+
402+
wrapper = PyTorchYoloLossWrapper(dummy_model, name="yolov8n")
403+
wrapper.train()
404+
# Dummy input and targets
405+
x = torch.zeros((1, 3, 416, 416))
406+
targets = [{"boxes": torch.zeros((1, 4)), "labels": torch.zeros((1,))}]
407+
losses = wrapper(x, targets)
408+
assert set(losses.keys()) == {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
409+
assert losses["loss_total"].item() == 6.0 # sum([1.0, 2.0, 3.0])
410+
assert losses["loss_box"].item() == 1.0
411+
assert losses["loss_cls"].item() == 2.0
412+
assert losses["loss_dfl"].item() == 3.0

0 commit comments

Comments
 (0)