Skip to content

Commit 4ab321b

Browse files
author
Beat Buesser
committed
Update progress bars in PGDs
Signed-off-by: Beat Buesser <[email protected]>
1 parent 71bc9cd commit 4ab321b

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_numpy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import numpy as np
3232
from scipy.stats import truncnorm
33+
from tqdm import trange
3334

3435
from art.attacks.evasion.fast_gradient import FastGradientMethod
3536
from art.config import ART_NUMPY_DTYPE
@@ -266,10 +267,10 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
266267
adv_x_best = None
267268
rate_best = None
268269

269-
for _ in range(max(1, self.num_random_init)):
270+
for _ in trange(max(1, self.num_random_init), desc="PGD - Random Initializations"):
270271
adv_x = x.astype(ART_NUMPY_DTYPE)
271272

272-
for i_max_iter in range(self.max_iter):
273+
for i_max_iter in trange(self.max_iter, desc="PGD - Iterations", leave=False):
273274
adv_x = self._compute(
274275
adv_x,
275276
x,
@@ -313,7 +314,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
313314
# Start to compute adversarial examples
314315
adv_x = x.astype(ART_NUMPY_DTYPE)
315316

316-
for i_max_iter in range(self.max_iter):
317+
for i_max_iter in trange(self.max_iter, desc="PGD - Iterations"):
317318
adv_x = self._compute(
318319
adv_x,
319320
x,

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from typing import Optional, Union, TYPE_CHECKING
3030

3131
import numpy as np
32+
from tqdm import trange, tqdm
3233

3334
from art.config import ART_NUMPY_DTYPE
3435
from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent_numpy import (
@@ -163,11 +164,11 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
163164
adv_x_best = None
164165
rate_best = None
165166

166-
for _ in range(max(1, self.num_random_init)):
167+
for _ in trange(max(1, self.num_random_init), desc="PGD - Random Initializations"):
167168
adv_x = x.astype(ART_NUMPY_DTYPE)
168169

169170
# Compute perturbation with batching
170-
for (batch_id, batch_all) in enumerate(data_loader):
171+
for (batch_id, batch_all) in enumerate(tqdm(data_loader, desc="PGD - Iterations", leave=False)):
171172
if mask is not None:
172173
(batch, batch_labels, mask_batch) = batch_all[0], batch_all[1], batch_all[2]
173174
else:

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_tensorflow_v2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from typing import Optional, Union, TYPE_CHECKING
3030

3131
import numpy as np
32+
from tqdm import trange, tqdm
3233

3334
from art.config import ART_NUMPY_DTYPE
3435
from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent_numpy import (
@@ -158,12 +159,12 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
158159
adv_x_best = None
159160
rate_best = None
160161

161-
for _ in range(max(1, self.num_random_init)):
162+
for _ in trange(max(1, self.num_random_init), desc="PGD - Random Initializations"):
162163
adv_x = x.astype(ART_NUMPY_DTYPE)
163164
data_loader = iter(dataset)
164165

165166
# Compute perturbation with batching
166-
for (batch_id, batch_all) in enumerate(data_loader):
167+
for (batch_id, batch_all) in enumerate(tqdm(data_loader, desc="PGD - Iterations", leave=False)):
167168
if mask is not None:
168169
(batch, batch_labels, mask_batch) = batch_all[0], batch_all[1], batch_all[2]
169170
else:

art/defences/trainer/adversarial_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
120120
# Precompute adversarial samples for transferred attacks
121121
logged = False
122122
self._precomputed_adv_samples = []
123-
for attack in tqdm(self.attacks, desc="Precompute adv samples"):
123+
for attack in tqdm(self.attacks, desc="Precompute adversarial examples."):
124124
if "targeted" in attack.attack_params and attack.targeted: # type: ignore
125125
raise NotImplementedError("Adversarial training with targeted attacks is currently not implemented")
126126

art/defences/trainer/adversarial_trainer_fbf_pytorch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from typing import Optional, Tuple, Union, TYPE_CHECKING
2828

2929
import numpy as np
30+
from tqdm import trange
3031

3132
from art.config import ART_NUMPY_DTYPE
3233
from art.defences.trainer.adversarial_trainer_fbf import AdversarialTrainerFBF
@@ -92,9 +93,9 @@ def fit(
9293
def lr_schedule(t):
9394
return np.interp([t], [0, nb_epochs * 2 // 5, nb_epochs], [0, 0.21, 0])[0]
9495

95-
logger.info("Adversarial training FBF")
96+
logger.info("Adversarial Training FBF")
9697

97-
for i_epoch in range(nb_epochs):
98+
for i_epoch in trange(nb_epochs, desc="Adversarial Training FBF - Epochs"):
9899
# Shuffle the examples
99100
np.random.shuffle(ind)
100101
start_time = time.time()
@@ -160,9 +161,9 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
160161
def lr_schedule(t):
161162
return np.interp([t], [0, nb_epochs * 2 // 5, nb_epochs], [0, 0.21, 0])[0]
162163

163-
logger.info("Adversarial training FBF")
164+
logger.info("Adversarial Training FBF")
164165

165-
for i_epoch in range(nb_epochs):
166+
for i_epoch in trange(nb_epochs, desc="Adversarial Training FBF - Epochs"):
166167
start_time = time.time()
167168
train_loss = 0.0
168169
train_acc = 0.0

0 commit comments

Comments (0)