@@ -74,9 +74,9 @@ def _get_adv_trainer_awptrades():
74
74
norm = np .inf ,
75
75
eps = 0.2 ,
76
76
eps_step = 0.02 ,
77
- max_iter = 20 ,
77
+ max_iter = 5 ,
78
78
targeted = False ,
79
- num_random_init = 1 ,
79
+ num_random_init = 0 ,
80
80
batch_size = 128 ,
81
81
verbose = False ,
82
82
)
@@ -141,7 +141,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_and_predict(get_adv_trainer_awpp
141
141
assert accuracy == 0.32
142
142
assert accuracy_new > 0.32
143
143
144
- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
144
+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 , validation_data = (x_train_mnist , y_train_mnist ))
145
145
146
146
147
147
@pytest .mark .only_with_platform ("pytorch" )
@@ -171,7 +171,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
171
171
else :
172
172
accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
173
173
174
- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 )
174
+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 )
175
175
predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
176
176
177
177
if label_format == "one_hot" :
@@ -188,7 +188,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
188
188
assert accuracy == 0.32
189
189
assert accuracy_new > 0.32
190
190
191
- trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
191
+ trainer .fit (x_train_mnist , y_train_mnist , nb_epochs = 5 , validation_data = (x_train_mnist , y_train_mnist ))
192
192
193
193
194
194
@pytest .mark .only_with_platform ("pytorch" )
@@ -219,7 +219,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
219
219
else :
220
220
accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
221
221
222
- trainer .fit_generator (generator = generator , nb_epochs = 20 )
222
+ trainer .fit_generator (generator = generator , nb_epochs = 5 )
223
223
predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
224
224
225
225
if label_format == "one_hot" :
@@ -236,7 +236,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
236
236
assert accuracy == 0.32
237
237
assert accuracy_new > 0.32
238
238
239
- trainer .fit_generator (generator = generator , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
239
+ trainer .fit_generator (generator = generator , nb_epochs = 2 , validation_data = (x_train_mnist , y_train_mnist ))
240
240
241
241
242
242
@pytest .mark .only_with_platform ("pytorch" )
@@ -267,7 +267,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
267
267
else :
268
268
accuracy = np .sum (predictions == y_test_mnist ) / x_test_mnist .shape [0 ]
269
269
270
- trainer .fit_generator (generator = generator , nb_epochs = 20 )
270
+ trainer .fit_generator (generator = generator , nb_epochs = 5 )
271
271
predictions_new = np .argmax (trainer .predict (x_test_mnist ), axis = 1 )
272
272
273
273
if label_format == "one_hot" :
@@ -284,4 +284,4 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
284
284
assert accuracy == 0.32
285
285
assert accuracy_new > 0.32
286
286
287
- trainer .fit_generator (generator = generator , nb_epochs = 20 , validation_data = (x_train_mnist , y_train_mnist ))
287
+ trainer .fit_generator (generator = generator , nb_epochs = 2 , validation_data = (x_train_mnist , y_train_mnist ))
0 commit comments