Skip to content

Commit 5d06f1d

Browse files
committed
Optimise test time
Signed-off-by: Beat Buesser <[email protected]>
1 parent 8323383 commit 5d06f1d

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

tests/defences/trainer/test_adversarial_trainer_awp_pytorch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def _get_adv_trainer_awptrades():
7474
norm=np.inf,
7575
eps=0.2,
7676
eps_step=0.02,
77-
max_iter=20,
77+
max_iter=5,
7878
targeted=False,
79-
num_random_init=1,
79+
num_random_init=0,
8080
batch_size=128,
8181
verbose=False,
8282
)
@@ -141,7 +141,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_and_predict(get_adv_trainer_awpp
141141
assert accuracy == 0.32
142142
assert accuracy_new > 0.32
143143

144-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
144+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5, validation_data=(x_train_mnist, y_train_mnist))
145145

146146

147147
@pytest.mark.only_with_platform("pytorch")
@@ -171,7 +171,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
171171
else:
172172
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
173173

174-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
174+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5)
175175
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
176176

177177
if label_format == "one_hot":
@@ -188,7 +188,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_and_predict(
188188
assert accuracy == 0.32
189189
assert accuracy_new > 0.32
190190

191-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
191+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5, validation_data=(x_train_mnist, y_train_mnist))
192192

193193

194194
@pytest.mark.only_with_platform("pytorch")
@@ -219,7 +219,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
219219
else:
220220
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
221221

222-
trainer.fit_generator(generator=generator, nb_epochs=20)
222+
trainer.fit_generator(generator=generator, nb_epochs=5)
223223
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
224224

225225
if label_format == "one_hot":
@@ -236,7 +236,7 @@ def test_adversarial_trainer_awppgd_pytorch_fit_generator_and_predict(
236236
assert accuracy == 0.32
237237
assert accuracy_new > 0.32
238238

239-
trainer.fit_generator(generator=generator, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
239+
trainer.fit_generator(generator=generator, nb_epochs=2, validation_data=(x_train_mnist, y_train_mnist))
240240

241241

242242
@pytest.mark.only_with_platform("pytorch")
@@ -267,7 +267,7 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
267267
else:
268268
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
269269

270-
trainer.fit_generator(generator=generator, nb_epochs=20)
270+
trainer.fit_generator(generator=generator, nb_epochs=5)
271271
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
272272

273273
if label_format == "one_hot":
@@ -284,4 +284,4 @@ def test_adversarial_trainer_awptrades_pytorch_fit_generator_and_predict(
284284
assert accuracy == 0.32
285285
assert accuracy_new > 0.32
286286

287-
trainer.fit_generator(generator=generator, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
287+
trainer.fit_generator(generator=generator, nb_epochs=2, validation_data=(x_train_mnist, y_train_mnist))

tests/defences/trainer/test_adversarial_trainer_oaat_pytorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _get_adv_trainer_oaat():
5959
"alternate_iter_eps": 0.15,
6060
"swa_save_epoch": 0,
6161
"list_swa_epoch": [0, 15, 15],
62-
"max_iter": 20,
62+
"max_iter": 5,
6363
"models_path": None,
6464
"load_swa_model_tau": 0.995,
6565
"layer_names_activation": ["conv"],
@@ -120,7 +120,7 @@ def test_adversarial_trainer_oaat_pytorch_fit_and_predict(get_adv_trainer_oaat,
120120
else:
121121
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
122122

123-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, batch_size=16)
123+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5, batch_size=16)
124124
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
125125

126126
if label_format == "one_hot":
@@ -138,7 +138,7 @@ def test_adversarial_trainer_oaat_pytorch_fit_and_predict(get_adv_trainer_oaat,
138138
assert accuracy_new > 0.32
139139

140140
trainer.fit(
141-
x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist), batch_size=16
141+
x_train_mnist, y_train_mnist, nb_epochs=2, validation_data=(x_train_mnist, y_train_mnist), batch_size=16
142142
)
143143

144144

@@ -170,7 +170,7 @@ def test_adversarial_trainer_oaat_pytorch_fit_generator_and_predict(
170170
else:
171171
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
172172

173-
trainer.fit_generator(generator=generator, nb_epochs=20)
173+
trainer.fit_generator(generator=generator, nb_epochs=10)
174174
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
175175

176176
if label_format == "one_hot":
@@ -187,4 +187,4 @@ def test_adversarial_trainer_oaat_pytorch_fit_generator_and_predict(
187187
assert accuracy == 0.32
188188
assert accuracy_new > 0.32
189189

190-
trainer.fit_generator(generator=generator, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
190+
trainer.fit_generator(generator=generator, nb_epochs=2, validation_data=(x_train_mnist, y_train_mnist))

tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _get_adv_trainer():
4040
norm=np.inf,
4141
eps=0.3,
4242
eps_step=0.03,
43-
max_iter=20,
43+
max_iter=5,
4444
targeted=False,
4545
num_random_init=1,
4646
batch_size=128,
@@ -89,7 +89,7 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
8989
else:
9090
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
9191

92-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
92+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=5)
9393
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
9494

9595
if label_format == "one_hot":
@@ -106,7 +106,7 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
106106
assert accuracy == 0.32
107107
assert accuracy_new > 0.32
108108

109-
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
109+
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=2, validation_data=(x_train_mnist, y_train_mnist))
110110

111111

112112
@pytest.mark.only_with_platform("pytorch")
@@ -136,7 +136,7 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
136136
else:
137137
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
138138

139-
trainer.fit_generator(generator=generator, nb_epochs=20)
139+
trainer.fit_generator(generator=generator, nb_epochs=5)
140140
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
141141

142142
if label_format == "one_hot":

0 commit comments

Comments
 (0)