@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
6363 yield x_train_mnist [:n_train ], y_train_mnist [:n_train ], x_test_mnist [:n_test ], y_test_mnist [:n_test ]
6464
6565
66- @pytest .mark .skip_framework ("tensorflow" , "keras" , "scikitlearn" , "mxnet" , "kerastf" )
67- def test_adversarial_trainer_trades_pytorch_fit_and_predict (get_adv_trainer , fix_get_mnist_subset ):
66+ @pytest .mark .only_with_platform ("pytorch" )
67+ @pytest .mark .parametrize ("label_format" , ["one_hot" , "numerical" ])
68+ def test_adversarial_trainer_trades_pytorch_fit_and_predict (get_adv_trainer , fix_get_mnist_subset , label_format ):
6869 (x_train_mnist , y_train_mnist , x_test_mnist , y_test_mnist ) = fix_get_mnist_subset
6970 x_test_mnist_original = x_test_mnist .copy ()
7071
72+ if label_format == "one_hot" :
73+ assert y_train_mnist .shape [- 1 ] == 10
74+ assert y_test_mnist .shape [- 1 ] == 10
75+ if label_format == "numerical" :
76+ y_test_mnist = np .argmax (y_test_mnist , axis = 1 )
77+ y_train_mnist = np .argmax (y_train_mnist , axis = 1 )
78+
7179 trainer = get_adv_trainer ()
7280 if trainer is None :
7381 logging .warning ("Couldn't perform this test because no trainer is defined for this framework configuration" )
7482 return
7583
7684 predictions = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
77- accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
85+
86+ if label_format == "one_hot" :
87+ accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
88+ else :
89+ accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
7890
7991 trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 )
8092 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
81- accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
93+
94+ if label_format == "one_hot" :
95+ accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
96+ else :
97+ accuracy_new = np .sum (predictions_new == y_test_mnist ) / x_test_mnist .shape [0 ]
8298
8399 np .testing .assert_array_almost_equal (
84100 float (np .mean (x_test_mnist_original - x_test_mnist )),
@@ -92,13 +108,21 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
92108 trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
93109
94110
95- @pytest .mark .skip_framework ("tensorflow" , "keras" , "scikitlearn" , "mxnet" , "kerastf" )
111+ @pytest .mark .only_with_platform ("pytorch" )
112+ @pytest .mark .parametrize ("label_format" , ["one_hot" , "numerical" ])
96113def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict (
97- get_adv_trainer , fix_get_mnist_subset , image_data_generator
114+ get_adv_trainer , fix_get_mnist_subset , image_data_generator , label_format
98115):
99116 (x_train_mnist , y_train_mnist , x_test_mnist , y_test_mnist ) = fix_get_mnist_subset
100117 x_test_mnist_original = x_test_mnist .copy ()
101118
119+ if label_format == "one_hot" :
120+ assert y_train_mnist .shape [- 1 ] == 10
121+ assert y_test_mnist .shape [- 1 ] == 10
122+ if label_format == "numerical" :
123+ y_test_mnist = np .argmax (y_test_mnist , axis = 1 )
124+ y_train_mnist = np .argmax (y_train_mnist , axis = 1 )
125+
102126 generator = image_data_generator ()
103127
104128 trainer = get_adv_trainer ()
@@ -107,11 +131,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
107131 return
108132
109133 predictions = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
110- accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
134+ if label_format == "one_hot" :
135+ accuracy = np .sum (predictions == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
136+ else :
137+ accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
111138
112139 trainer .fit_generator (generator = generator , nb_epochs = 20 )
113140 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
114- accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
141+
142+ if label_format == "one_hot" :
143+ accuracy_new = np .sum (predictions_new == np .argmax (y_test_mnist , axis = 1 )) / x_test_mnist .shape [0 ]
144+ else :
145+ accuracy_new = np .sum (predictions_new == y_test_mnist ) / x_test_mnist .shape [0 ]
115146
116147 np .testing .assert_array_almost_equal (
117148 float (np .mean (x_test_mnist_original - x_test_mnist )),
0 commit comments