@@ -324,89 +324,37 @@ def test_patch(art_warning, get_pytorch_yolo):
324
324
art_warning (e )
325
325
326
326
327
- @pytest .mark .only_with_platform ("pytorch" )
328
- def test_translate_predictions_yolov8_format ():
327
+ def test_import_pytorch_yolo_loss_wrapper ():
329
328
import torch
330
- import numpy as np
331
- from art .estimators .object_detection .pytorch_yolo import PyTorchYolo
329
+ from art .estimators .object_detection .pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
332
330
333
- # Create a dummy PyTorchYolo instance (model is not used for this test)
334
- class DummyModel (torch .nn .Module ):
335
- def forward (self , x ):
336
- return x
337
- dummy_model = DummyModel ()
338
- yolo = PyTorchYolo (
339
- model = dummy_model ,
340
- input_shape = (3 , 416 , 416 ),
341
- optimizer = None ,
342
- clip_values = (0 , 1 ),
343
- channels_first = True ,
344
- attack_losses = ("loss_total" ,),
345
- )
346
-
347
- # Mock YOLO v8+ style predictions: list of dicts with torch tensors
348
- pred_boxes = torch .tensor ([[10.0 , 20.0 , 30.0 , 40.0 ]], dtype = torch .float32 )
349
- pred_labels = torch .tensor ([5 ], dtype = torch .int64 )
350
- pred_scores = torch .tensor ([0.9 ], dtype = torch .float32 )
351
- predictions = [{
352
- "boxes" : pred_boxes ,
353
- "labels" : pred_labels ,
354
- "scores" : pred_scores ,
355
- }]
356
-
357
- # Call the translation method
358
- translated = yolo ._translate_predictions (predictions )
359
-
360
- # Check output type and values
361
- assert isinstance (translated , list )
362
- assert isinstance (translated [0 ], dict )
363
- assert isinstance (translated [0 ]["boxes" ], np .ndarray )
364
- assert isinstance (translated [0 ]["labels" ], np .ndarray )
365
- assert isinstance (translated [0 ]["scores" ], np .ndarray )
366
- np .testing .assert_array_equal (translated [0 ]["boxes" ], pred_boxes .numpy ())
367
- np .testing .assert_array_equal (translated [0 ]["labels" ], pred_labels .numpy ())
368
- np .testing .assert_array_equal (translated [0 ]["scores" ], pred_scores .numpy ())
369
-
370
-
371
- @pytest .mark .only_with_platform ("pytorch" )
372
- def test_pytorch_yolo_loss_wrapper_additional_losses ():
373
- import torch
374
- from art .estimators .object_detection .pytorch_yolo import PyTorchYoloLossWrapper
375
-
376
- # Dummy model with a .loss() method
377
331
class DummyModel (torch .nn .Module ):
378
332
def __init__ (self ):
379
333
super ().__init__ ()
334
+
380
335
def loss (self , items ):
381
- # Return (loss, [loss_box, loss_cls, loss_dfl])
382
- return (
383
- torch .tensor ([1.0 , 2.0 , 3.0 ]),
384
- [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )]
385
- )
336
+ return (torch .tensor ([1.0 ]), [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )])
386
337
387
338
dummy_model = DummyModel ()
388
339
# Patch ultralytics import in the wrapper
389
340
import sys
390
341
import types
342
+
391
343
ultralytics_mock = types .SimpleNamespace (
392
- models = types .SimpleNamespace (yolo = types .SimpleNamespace (detect = types .SimpleNamespace (DetectionPredictor = lambda : types .SimpleNamespace (args = None )))),
393
- utils = types .SimpleNamespace (loss = types .SimpleNamespace (v8DetectionLoss = lambda m : None , E2EDetectLoss = lambda m : None ))
344
+ models = types .SimpleNamespace (
345
+ yolo = types .SimpleNamespace (
346
+ detect = types .SimpleNamespace (DetectionPredictor = lambda : types .SimpleNamespace (args = None ))
347
+ )
348
+ ),
349
+ utils = types .SimpleNamespace (
350
+ loss = types .SimpleNamespace (v8DetectionLoss = lambda m : None , E2EDetectLoss = lambda m : None )
351
+ ),
394
352
)
395
- sys .modules ['ultralytics' ] = ultralytics_mock
396
- sys .modules ['ultralytics.models' ] = ultralytics_mock .models
397
- sys .modules ['ultralytics.models.yolo' ] = ultralytics_mock .models .yolo
398
- sys .modules ['ultralytics.models.yolo.detect' ] = ultralytics_mock .models .yolo .detect
399
- sys .modules ['ultralytics.utils' ] = ultralytics_mock .utils
400
- sys .modules ['ultralytics.utils.loss' ] = ultralytics_mock .utils .loss
401
-
353
+ sys .modules ["ultralytics" ] = ultralytics_mock
354
+ sys .modules ["ultralytics.models" ] = ultralytics_mock .models
355
+ sys .modules ["ultralytics.models.yolo" ] = ultralytics_mock .models .yolo
356
+ sys .modules ["ultralytics.models.yolo.detect" ] = ultralytics_mock .models .yolo .detect
357
+ sys .modules ["ultralytics.utils" ] = ultralytics_mock .utils
358
+ sys .modules ["ultralytics.utils.loss" ] = ultralytics_mock .utils .loss
402
359
wrapper = PyTorchYoloLossWrapper (dummy_model , name = "yolov8n" )
403
- wrapper .train ()
404
- # Dummy input and targets
405
- x = torch .zeros ((1 , 3 , 416 , 416 ))
406
- targets = [{"boxes" : torch .zeros ((1 , 4 )), "labels" : torch .zeros ((1 ,))}]
407
- losses = wrapper (x , targets )
408
- assert set (losses .keys ()) == {"loss_total" , "loss_box" , "loss_cls" , "loss_dfl" }
409
- assert losses ["loss_total" ].item () == 6.0 # sum([1.0, 2.0, 3.0])
410
- assert losses ["loss_box" ].item () == 1.0
411
- assert losses ["loss_cls" ].item () == 2.0
412
- assert losses ["loss_dfl" ].item () == 3.0
360
+ assert isinstance (wrapper , PyTorchYoloLossWrapper )
0 commit comments