@@ -322,3 +322,91 @@ def test_patch(art_warning, get_pytorch_yolo):
322322
323323 except ARTTestException as e :
324324 art_warning (e )
325+
326+
327+ @pytest .mark .only_with_platform ("pytorch" )
328+ def test_translate_predictions_yolov8_format ():
329+ import torch
330+ import numpy as np
331+ from art .estimators .object_detection .pytorch_yolo import PyTorchYolo
332+
333+ # Create a dummy PyTorchYolo instance (model is not used for this test)
334+ class DummyModel (torch .nn .Module ):
335+ def forward (self , x ):
336+ return x
337+ dummy_model = DummyModel ()
338+ yolo = PyTorchYolo (
339+ model = dummy_model ,
340+ input_shape = (3 , 416 , 416 ),
341+ optimizer = None ,
342+ clip_values = (0 , 1 ),
343+ channels_first = True ,
344+ attack_losses = ("loss_total" ,),
345+ )
346+
347+ # Mock YOLO v8+ style predictions: list of dicts with torch tensors
348+ pred_boxes = torch .tensor ([[10.0 , 20.0 , 30.0 , 40.0 ]], dtype = torch .float32 )
349+ pred_labels = torch .tensor ([5 ], dtype = torch .int64 )
350+ pred_scores = torch .tensor ([0.9 ], dtype = torch .float32 )
351+ predictions = [{
352+ "boxes" : pred_boxes ,
353+ "labels" : pred_labels ,
354+ "scores" : pred_scores ,
355+ }]
356+
357+ # Call the translation method
358+ translated = yolo ._translate_predictions (predictions )
359+
360+ # Check output type and values
361+ assert isinstance (translated , list )
362+ assert isinstance (translated [0 ], dict )
363+ assert isinstance (translated [0 ]["boxes" ], np .ndarray )
364+ assert isinstance (translated [0 ]["labels" ], np .ndarray )
365+ assert isinstance (translated [0 ]["scores" ], np .ndarray )
366+ np .testing .assert_array_equal (translated [0 ]["boxes" ], pred_boxes .numpy ())
367+ np .testing .assert_array_equal (translated [0 ]["labels" ], pred_labels .numpy ())
368+ np .testing .assert_array_equal (translated [0 ]["scores" ], pred_scores .numpy ())
369+
370+
371+ @pytest .mark .only_with_platform ("pytorch" )
372+ def test_pytorch_yolo_loss_wrapper_additional_losses ():
373+ import torch
374+ from art .estimators .object_detection .pytorch_yolo import PyTorchYoloLossWrapper
375+
376+ # Dummy model with a .loss() method
377+ class DummyModel (torch .nn .Module ):
378+ def __init__ (self ):
379+ super ().__init__ ()
380+ def loss (self , items ):
381+ # Return (loss, [loss_box, loss_cls, loss_dfl])
382+ return (
383+ torch .tensor ([1.0 , 2.0 , 3.0 ]),
384+ [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )]
385+ )
386+
387+ dummy_model = DummyModel ()
388+ # Patch ultralytics import in the wrapper
389+ import sys
390+ import types
391+ ultralytics_mock = types .SimpleNamespace (
392+ models = types .SimpleNamespace (yolo = types .SimpleNamespace (detect = types .SimpleNamespace (DetectionPredictor = lambda : types .SimpleNamespace (args = None )))),
393+ utils = types .SimpleNamespace (loss = types .SimpleNamespace (v8DetectionLoss = lambda m : None , E2EDetectLoss = lambda m : None ))
394+ )
395+ sys .modules ['ultralytics' ] = ultralytics_mock
396+ sys .modules ['ultralytics.models' ] = ultralytics_mock .models
397+ sys .modules ['ultralytics.models.yolo' ] = ultralytics_mock .models .yolo
398+ sys .modules ['ultralytics.models.yolo.detect' ] = ultralytics_mock .models .yolo .detect
399+ sys .modules ['ultralytics.utils' ] = ultralytics_mock .utils
400+ sys .modules ['ultralytics.utils.loss' ] = ultralytics_mock .utils .loss
401+
402+ wrapper = PyTorchYoloLossWrapper (dummy_model , name = "yolov8n" )
403+ wrapper .train ()
404+ # Dummy input and targets
405+ x = torch .zeros ((1 , 3 , 416 , 416 ))
406+ targets = [{"boxes" : torch .zeros ((1 , 4 )), "labels" : torch .zeros ((1 ,))}]
407+ losses = wrapper (x , targets )
408+ assert set (losses .keys ()) == {"loss_total" , "loss_box" , "loss_cls" , "loss_dfl" }
409+ assert losses ["loss_total" ].item () == 6.0 # sum([1.0, 2.0, 3.0])
410+ assert losses ["loss_box" ].item () == 1.0
411+ assert losses ["loss_cls" ].item () == 2.0
412+ assert losses ["loss_dfl" ].item () == 3.0
0 commit comments