@@ -34,7 +34,7 @@ def _get_adv_trainer():
3434 if framework in ["tensorflow" , "tensorflow2v1" ]:
3535 trainer = None
3636 if framework == "pytorch" :
37- classifier , _ = image_dl_estimator ()
37+ classifier , _ = image_dl_estimator (from_logits = True )
3838 attack = ProjectedGradientDescent (
3939 classifier ,
4040 norm = np .inf ,
@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
6363 yield x_train_mnist [:n_train ], y_train_mnist [:n_train ], x_test_mnist [:n_test ], y_test_mnist [:n_test ]
6464
6565
66- @pytest .mark .skip_framework ("tensorflow" , "keras" , "scikitlearn" , "mxnet" , "kerastf" )
67- def test_adversarial_trainer_trades_pytorch_fit_and_predict (get_adv_trainer , fix_get_mnist_subset ):
66+ @pytest .mark .only_with_platform ("pytorch" )
67+ @pytest .mark .parametrize ("label_format" , ["one_hot" , "numerical" ])
68+ def test_adversarial_trainer_trades_pytorch_fit_and_predict (get_adv_trainer , fix_get_mnist_subset , label_format ):
6869 (x_train_mnist , y_train_mnist , x_test_mnist , y_test_mnist ) = fix_get_mnist_subset
6970 x_test_mnist_original = x_test_mnist .copy ()
7071
72+ if label_format == "one_hot" :
73+ assert y_train_mnist .shape [- 1 ] == 10
74+ assert y_test_mnist .shape [- 1 ] == 10
75+ if label_format == "numerical" :
76+ y_test_mnist = np .argmax (y_test_mnist , axis = 1 )
77+ y_train_mnist = np .argmax (y_train_mnist , axis = 1 )
78+
7179 trainer = get_adv_trainer ()
7280 if trainer is None :
7381 logging .warning ("Couldn't perform this test because no trainer is defined for this framework configuration" )
7482 return
7583
7684 predictions = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
77- accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
85+
86+ if label_format == "one_hot" :
87+ accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
88+ else :
89+ accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
7890
7991 trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 )
8092 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
81- accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
93+
94+ if label_format == "one_hot" :
95+ accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
96+ else :
97+ accuracy_new = np .sum (predictions_new == y_test_mnist ) / x_test_mnist .shape [0 ]
8298
8399 np .testing .assert_array_almost_equal (
84100 float (np .mean (x_test_mnist_original - x_test_mnist )),
@@ -92,13 +108,20 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
92108 trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
93109
94110
95- @pytest .mark .skip_framework ("tensorflow" , "keras" , "scikitlearn" , "mxnet" , "kerastf" )
111+ @pytest .mark .only_with_platform ("pytorch" )
112+ @pytest .mark .parametrize ("label_format" , ["one_hot" , "numerical" ])
96113def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict (
97- get_adv_trainer , fix_get_mnist_subset , image_data_generator
114+ get_adv_trainer , fix_get_mnist_subset , image_data_generator , label_format
98115):
99116 (x_train_mnist , y_train_mnist , x_test_mnist , y_test_mnist ) = fix_get_mnist_subset
100117 x_test_mnist_original = x_test_mnist .copy ()
101118
119+ if label_format == "one_hot" :
120+ assert y_train_mnist .shape [- 1 ] == 10
121+ assert y_test_mnist .shape [- 1 ] == 10
122+ if label_format == "numerical" :
123+ y_test_mnist = np .argmax (y_test_mnist , axis = 1 )
124+
102125 generator = image_data_generator ()
103126
104127 trainer = get_adv_trainer ()
@@ -107,11 +130,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
107130 return
108131
109132 predictions = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
110- accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
133+ if label_format == "one_hot" :
134+ accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
135+ else :
136+ accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
111137
112138 trainer .fit_generator (generator = generator , nb_epochs = 20 )
113139 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
114- accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
140+
141+ if label_format == "one_hot" :
142+ accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
143+ else :
144+ accuracy_new = np .sum (predictions_new == y_test_mnist ) / x_test_mnist .shape [0 ]
115145
116146 np .testing .assert_array_almost_equal (
117147 float (np .mean (x_test_mnist_original - x_test_mnist )),
0 commit comments