Skip to content

Commit 5a0fe10

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Fix doc & PEP8 for AdversarialTrainer
1 parent 82f4204 commit 5a0fe10

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

art/defences/adversarial_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ class AdversarialTrainer:
1515
training on all adversarial data and other common setups. If multiple attacks are specified, they are rotated
1616
for each batch. If the specified attacks have as target a different model, then the attack is transferred. The
1717
`ratio` determines how many of the clean samples in each batch are replaced with their adversarial counterpart.
18-
When the attack targets the current classifier, only successful adversarial samples are used.
18+
19+
.. warning:: Both successful and unsuccessful adversarial samples are used for training. In the case of
20+
unbounded attacks (e.g., DeepFool), this can result in invalid (very noisy) samples being included.
1921
"""
2022
def __init__(self, classifier, attacks, ratio=.5):
2123
"""
@@ -71,7 +73,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
7173
self._precomputed_adv_samples = []
7274
for attack in self.attacks:
7375
if 'targeted' in attack.attack_params:
74-
if attack.targeted:
76+
if attack.targeted:
7577
raise NotImplementedError("Adversarial training with targeted attacks is \
7678
currently not implemented")
7779

@@ -153,7 +155,7 @@ def fit(self, x, y, **kwargs):
153155

154156
# Generate adversarial samples for each attack
155157
for i, attack in enumerate(self.attacks):
156-
if 'targeted' in attack.attack_params and attack.targeted:
158+
if 'targeted' in attack.attack_params and attack.targeted:
157159
raise NotImplementedError("Adversarial training with targeted attacks is \
158160
currently not implemented")
159161

art/defences/adversarial_trainer_unittest.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,15 @@ def test_two_attacks(self):
169169

170170
logger.info('Accuracy before adversarial training: %.2f%%', (acc * 100))
171171
logger.info('\nAccuracy after adversarial training: %.2f%%', (acc_new * 100))
172-
173-
172+
174173
def test_targeted_attack_error(self):
175174
"""
176175
Test the adversarial trainer using a targeted attack, which will currently result in a
177176
NotImplementError.
178177
179178
:return: None
180179
"""
181-
182-
(x_train, y_train), (x_test, y_test) = self.mnist
180+
(x_train, y_train), (_, _) = self.mnist
183181
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
184182

185183
classifier = self.classifier_k
@@ -294,14 +292,14 @@ def test_targeted_attack_error(self):
294292
295293
:return: None
296294
"""
297-
298-
(x_train, y_train), (x_test, y_test) = self.mnist
295+
(x_train, y_train), (_, _) = self.mnist
299296
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
300297

301298
classifier = self.classifier_k
302299
adv = FastGradientMethod(classifier, targeted=True)
303300
adv_trainer = StaticAdversarialTrainer(classifier, attacks=adv)
304301
self.assertRaises(NotImplementedError, adv_trainer.fit, x_train, y_train, **params)
305302

303+
306304
if __name__ == '__main__':
307305
unittest.main()

0 commit comments

Comments
 (0)