@@ -367,7 +367,7 @@ def test_compute_loss(art_warning, get_pytorch_yolo):
367367 # Compute loss
368368 loss = object_detector .compute_loss (x = x_test , y = y_test )
369369
370- assert pytest .approx (11.20741 , abs = 0.9 ) == float (loss )
370+ assert pytest .approx (11.20741 , abs = 1.5 ) == float (loss )
371371
372372 except ARTTestException as e :
373373 art_warning (e )
@@ -386,3 +386,52 @@ def test_pgd(art_warning, get_pytorch_yolo):
386386
387387 except ARTTestException as e :
388388 art_warning (e )
389+
390+ @pytest .mark .only_with_platform ("pytorch" )
391+ def test_patch (art_warning , get_pytorch_yolo ):
392+ try :
393+
394+ from art .attacks .evasion import AdversarialPatchPyTorch
395+
396+ rotation_max = 0.0
397+ scale_min = 0.1
398+ scale_max = 0.3
399+ distortion_scale_max = 0.0
400+ learning_rate = 1.99
401+ max_iter = 2
402+ batch_size = 16
403+ patch_shape = (3 , 5 , 5 )
404+ patch_type = "circle"
405+ optimizer = "pgd"
406+
407+ object_detector , x_test , y_test = get_pytorch_yolo
408+
409+ ap = AdversarialPatchPyTorch (estimator = object_detector , rotation_max = rotation_max ,
410+ scale_min = scale_min , scale_max = scale_max , optimizer = optimizer , distortion_scale_max = distortion_scale_max ,
411+ learning_rate = learning_rate , max_iter = max_iter , batch_size = batch_size ,
412+ patch_shape = patch_shape , patch_type = patch_type , verbose = True , targeted = False )
413+
414+ _ , _ = ap .generate (x = x_test , y = y_test )
415+
416+ patched_images = ap .apply_patch (x_test , scale = 0.4 )
417+ result = object_detector .predict (patched_images )
418+
419+ assert result [0 ]["scores" ].shape == (10647 ,)
420+ expected_detection_scores = np .asarray (
421+ [
422+ 4.3653536e-08 ,
423+ 3.3987994e-06 ,
424+ 2.5681820e-06 ,
425+ 3.9782722e-06 ,
426+ 2.1766680e-05 ,
427+ 2.6138965e-05 ,
428+ 6.3377396e-05 ,
429+ 7.6248516e-06 ,
430+ 4.3447722e-06 ,
431+ 3.6515078e-06 ,
432+ ]
433+ )
434+ np .testing .assert_raises (AssertionError , np .testing .assert_array_almost_equal , result [0 ]["scores" ][:10 ], expected_detection_scores , 6 )
435+
436+ except ARTTestException as e :
437+ art_warning (e )
0 commit comments