@@ -322,3 +322,91 @@ def test_patch(art_warning, get_pytorch_yolo):
322
322
323
323
except ARTTestException as e :
324
324
art_warning (e )
325
+
326
+
327
+ @pytest .mark .only_with_platform ("pytorch" )
328
+ def test_translate_predictions_yolov8_format ():
329
+ import torch
330
+ import numpy as np
331
+ from art .estimators .object_detection .pytorch_yolo import PyTorchYolo
332
+
333
+ # Create a dummy PyTorchYolo instance (model is not used for this test)
334
+ class DummyModel (torch .nn .Module ):
335
+ def forward (self , x ):
336
+ return x
337
+ dummy_model = DummyModel ()
338
+ yolo = PyTorchYolo (
339
+ model = dummy_model ,
340
+ input_shape = (3 , 416 , 416 ),
341
+ optimizer = None ,
342
+ clip_values = (0 , 1 ),
343
+ channels_first = True ,
344
+ attack_losses = ("loss_total" ,),
345
+ )
346
+
347
+ # Mock YOLO v8+ style predictions: list of dicts with torch tensors
348
+ pred_boxes = torch .tensor ([[10.0 , 20.0 , 30.0 , 40.0 ]], dtype = torch .float32 )
349
+ pred_labels = torch .tensor ([5 ], dtype = torch .int64 )
350
+ pred_scores = torch .tensor ([0.9 ], dtype = torch .float32 )
351
+ predictions = [{
352
+ "boxes" : pred_boxes ,
353
+ "labels" : pred_labels ,
354
+ "scores" : pred_scores ,
355
+ }]
356
+
357
+ # Call the translation method
358
+ translated = yolo ._translate_predictions (predictions )
359
+
360
+ # Check output type and values
361
+ assert isinstance (translated , list )
362
+ assert isinstance (translated [0 ], dict )
363
+ assert isinstance (translated [0 ]["boxes" ], np .ndarray )
364
+ assert isinstance (translated [0 ]["labels" ], np .ndarray )
365
+ assert isinstance (translated [0 ]["scores" ], np .ndarray )
366
+ np .testing .assert_array_equal (translated [0 ]["boxes" ], pred_boxes .numpy ())
367
+ np .testing .assert_array_equal (translated [0 ]["labels" ], pred_labels .numpy ())
368
+ np .testing .assert_array_equal (translated [0 ]["scores" ], pred_scores .numpy ())
369
+
370
+
371
+ @pytest .mark .only_with_platform ("pytorch" )
372
+ def test_pytorch_yolo_loss_wrapper_additional_losses ():
373
+ import torch
374
+ from art .estimators .object_detection .pytorch_yolo import PyTorchYoloLossWrapper
375
+
376
+ # Dummy model with a .loss() method
377
+ class DummyModel (torch .nn .Module ):
378
+ def __init__ (self ):
379
+ super ().__init__ ()
380
+ def loss (self , items ):
381
+ # Return (loss, [loss_box, loss_cls, loss_dfl])
382
+ return (
383
+ torch .tensor ([1.0 , 2.0 , 3.0 ]),
384
+ [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )]
385
+ )
386
+
387
+ dummy_model = DummyModel ()
388
+ # Patch ultralytics import in the wrapper
389
+ import sys
390
+ import types
391
+ ultralytics_mock = types .SimpleNamespace (
392
+ models = types .SimpleNamespace (yolo = types .SimpleNamespace (detect = types .SimpleNamespace (DetectionPredictor = lambda : types .SimpleNamespace (args = None )))),
393
+ utils = types .SimpleNamespace (loss = types .SimpleNamespace (v8DetectionLoss = lambda m : None , E2EDetectLoss = lambda m : None ))
394
+ )
395
+ sys .modules ['ultralytics' ] = ultralytics_mock
396
+ sys .modules ['ultralytics.models' ] = ultralytics_mock .models
397
+ sys .modules ['ultralytics.models.yolo' ] = ultralytics_mock .models .yolo
398
+ sys .modules ['ultralytics.models.yolo.detect' ] = ultralytics_mock .models .yolo .detect
399
+ sys .modules ['ultralytics.utils' ] = ultralytics_mock .utils
400
+ sys .modules ['ultralytics.utils.loss' ] = ultralytics_mock .utils .loss
401
+
402
+ wrapper = PyTorchYoloLossWrapper (dummy_model , name = "yolov8n" )
403
+ wrapper .train ()
404
+ # Dummy input and targets
405
+ x = torch .zeros ((1 , 3 , 416 , 416 ))
406
+ targets = [{"boxes" : torch .zeros ((1 , 4 )), "labels" : torch .zeros ((1 ,))}]
407
+ losses = wrapper (x , targets )
408
+ assert set (losses .keys ()) == {"loss_total" , "loss_box" , "loss_cls" , "loss_dfl" }
409
+ assert losses ["loss_total" ].item () == 6.0 # sum([1.0, 2.0, 3.0])
410
+ assert losses ["loss_box" ].item () == 1.0
411
+ assert losses ["loss_cls" ].item () == 2.0
412
+ assert losses ["loss_dfl" ].item () == 3.0
0 commit comments