Skip to content

Commit 0720265

Browse files
committed
Optimise test time
Signed-off-by: Beat Buesser <[email protected]>
1 parent 5d06f1d commit 0720265

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/defences/trainer/test_adversarial_trainer_oaat_pytorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ def _get_adv_trainer_oaat():
6969
classifier,
7070
norm=np.inf,
7171
eps=0.2,
72-
eps_step=0.02,
73-
max_iter=20,
72+
eps_step=0.01,
73+
max_iter=5,
7474
targeted=False,
75-
num_random_init=1,
75+
num_random_init=0,
7676
batch_size=16,
7777
verbose=False,
7878
)
@@ -120,7 +120,7 @@ def test_adversarial_trainer_oaat_pytorch_fit_and_predict(get_adv_trainer_oaat,
120120
else:
121121
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
122122

123-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5, batch_size=16)
123+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=10, batch_size=16)
124124
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
125125

126126
if label_format == "one_hot":
@@ -170,7 +170,7 @@ def test_adversarial_trainer_oaat_pytorch_fit_generator_and_predict(
170170
else:
171171
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
172172

173-
trainer.fit_generator(generator=generator, nb_epochs=10)
173+
trainer.fit_generator(generator=generator, nb_epochs=5)
174174
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
175175

176176
if label_format == "one_hot":

0 commit comments

Comments
 (0)