@@ -34,18 +34,13 @@ def fix_get_mnist_subset(get_mnist_dataset):
3434 (x_train_mnist , y_train_mnist ), (x_test_mnist , y_test_mnist ) = get_mnist_dataset
3535 n_train = 100
3636 n_test = 11
37- yield ( x_train_mnist [:n_train ], y_train_mnist [:n_train ], x_test_mnist [:n_test ], y_test_mnist [:n_test ])
37+ yield x_train_mnist [:n_train ], y_train_mnist [:n_train ], x_test_mnist [:n_test ], y_test_mnist [:n_test ]
3838
3939
40- @pytest .mark .only_with_platform ("pytorch" )
4140def test_generate (fix_get_mnist_subset , get_image_classifier_list_for_attack ):
4241
4342 classifier_list = get_image_classifier_list_for_attack (ShadowAttack )
4443
45- if classifier_list is None :
46- logging .warning ("Couldn't perform this test because no classifier is defined" )
47- return
48-
4944 for classifier in classifier_list :
5045 attack = ShadowAttack (
5146 estimator = classifier ,
@@ -61,15 +56,11 @@ def test_generate(fix_get_mnist_subset, get_image_classifier_list_for_attack):
6156
6257 (x_train_mnist , y_train_mnist , x_test_mnist , y_test_mnist ) = fix_get_mnist_subset
6358
64- if attack .framework == "pytorch" :
65- x_train_mnist = x_train_mnist .transpose ((0 , 3 , 1 , 2 ))
66-
6759 x_train_mnist_adv = attack .generate (x = x_train_mnist [0 :1 ], y = y_train_mnist [0 :1 ])
6860
6961 assert np .max (np .abs (x_train_mnist_adv - x_train_mnist [0 :1 ])) == pytest .approx (0.34966960549354553 , 0.06 )
7062
7163
72- @pytest .mark .only_with_platform ("pytorch" )
7364def test_get_regularisation_loss_gradients (fix_get_mnist_subset , get_image_classifier_list_for_attack ):
7465
7566 classifier_list = get_image_classifier_list_for_attack (ShadowAttack )
@@ -90,9 +81,6 @@ def test_get_regularisation_loss_gradients(fix_get_mnist_subset, get_image_class
9081
9182 (x_train_mnist , _ , _ , _ ) = fix_get_mnist_subset
9283
93- if attack .framework == "pytorch" :
94- x_train_mnist = x_train_mnist .transpose ((0 , 3 , 1 , 2 ))
95-
9684 gradients = attack ._get_regularisation_loss_gradients (x_train_mnist [0 :1 ])
9785
9886 gradients_expected = np .array (
0 commit comments