@@ -387,29 +387,41 @@ def test_pgd(art_warning, get_pytorch_yolo):
387387 except ARTTestException as e :
388388 art_warning (e )
389389
390+
390391@pytest .mark .only_with_platform ("pytorch" )
391392def test_patch (art_warning , get_pytorch_yolo ):
392393 try :
393-
394+
394395 from art .attacks .evasion import AdversarialPatchPyTorch
395396
396- rotation_max = 0.0
397- scale_min = 0.1
398- scale_max = 0.3
399- distortion_scale_max = 0.0
400- learning_rate = 1.99
401- max_iter = 2
402- batch_size = 16
403- patch_shape = (3 , 5 , 5 )
404- patch_type = "circle"
405- optimizer = "pgd"
397+ rotation_max = 0.0
398+ scale_min = 0.1
399+ scale_max = 0.3
400+ distortion_scale_max = 0.0
401+ learning_rate = 1.99
402+ max_iter = 2
403+ batch_size = 16
404+ patch_shape = (3 , 5 , 5 )
405+ patch_type = "circle"
406+ optimizer = "pgd"
406407
407408 object_detector , x_test , y_test = get_pytorch_yolo
408409
409- ap = AdversarialPatchPyTorch (estimator = object_detector , rotation_max = rotation_max ,
410- scale_min = scale_min , scale_max = scale_max , optimizer = optimizer , distortion_scale_max = distortion_scale_max ,
411- learning_rate = learning_rate , max_iter = max_iter , batch_size = batch_size ,
412- patch_shape = patch_shape , patch_type = patch_type , verbose = True , targeted = False )
410+ ap = AdversarialPatchPyTorch (
411+ estimator = object_detector ,
412+ rotation_max = rotation_max ,
413+ scale_min = scale_min ,
414+ scale_max = scale_max ,
415+ optimizer = optimizer ,
416+ distortion_scale_max = distortion_scale_max ,
417+ learning_rate = learning_rate ,
418+ max_iter = max_iter ,
419+ batch_size = batch_size ,
420+ patch_shape = patch_shape ,
421+ patch_type = patch_type ,
422+ verbose = True ,
423+ targeted = False ,
424+ )
413425
414426 _ , _ = ap .generate (x = x_test , y = y_test )
415427
@@ -431,7 +443,9 @@ def test_patch(art_warning, get_pytorch_yolo):
431443 3.6515078e-06 ,
432444 ]
433445 )
434- np .testing .assert_raises (AssertionError , np .testing .assert_array_almost_equal , result [0 ]["scores" ][:10 ], expected_detection_scores , 6 )
446+ np .testing .assert_raises (
447+ AssertionError , np .testing .assert_array_almost_equal , result [0 ]["scores" ][:10 ], expected_detection_scores , 6
448+ )
435449
436450 except ARTTestException as e :
437- art_warning (e )
451+ art_warning (e )
0 commit comments