@@ -74,9 +74,9 @@ def _get_adv_trainer_awptrades():
7474 norm = np .inf ,
7575 eps = 0.2 ,
7676 eps_step = 0.02 ,
77- max_iter = 20 ,
77+ max_iter = 5 ,
7878 targeted = False ,
79- num_random_init = 1 ,
79+ num_random_init = 0 ,
8080 batch_size = 128 ,
8181 verbose = False ,
8282 )
@@ -141,7 +141,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_and_predict(get_adv_trainer_awpp
141141 assert accuracy == 0.32
142142 assert accuracy_new > 0.32
143143
144- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
144+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 , validation_data = (x_train_mnist , y_train_mnist ))
145145
146146
147147@pytest .mark .only_with_platform ("pytorch" )
@@ -171,7 +171,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
171171 else :
172172 accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
173173
174- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 )
174+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 )
175175 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
176176
177177 if label_format == "one_hot" :
@@ -188,7 +188,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
188188 assert accuracy == 0.32
189189 assert accuracy_new > 0.32
190190
191- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
191+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 , validation_data = (x_train_mnist , y_train_mnist ))
192192
193193
194194@pytest .mark .only_with_platform ("pytorch" )
@@ -219,7 +219,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
219219 else :
220220 accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
221221
222- trainer .fit_generator (generator = generator , nb_epochs = 20 )
222+ trainer .fit_generator (generator = generator , nb_epochs = 5 )
223223 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
224224
225225 if label_format == "one_hot" :
@@ -236,7 +236,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
236236 assert accuracy == 0.32
237237 assert accuracy_new > 0.32
238238
239- trainer .fit_generator (generator = generator , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
239+ trainer .fit_generator (generator = generator , nb_epochs = 2 , validation_data = (x_train_mnist , y_train_mnist ))
240240
241241
242242@pytest .mark .only_with_platform ("pytorch" )
@@ -267,7 +267,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
267267 else :
268268 accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
269269
270- trainer .fit_generator (generator = generator , nb_epochs = 20 )
270+ trainer .fit_generator (generator = generator , nb_epochs = 5 )
271271 predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
272272
273273 if label_format == "one_hot" :
@@ -284,4 +284,4 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
284284 assert accuracy == 0.32
285285 assert accuracy_new > 0.32
286286
287- trainer .fit_generator (generator = generator , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
287+ trainer .fit_generator (generator = generator , nb_epochs = 2 , validation_data = (x_train_mnist , y_train_mnist ))
0 commit comments