Skip to content

Commit 82f4204

Browse files
MARIA NICOLAEGitHub Enterprise
authored andcommitted
Merge pull request #85 from MATHSINN/fix-adv-trainer
Copy batches and also use unsuccessful attack samples in `AdversarialTrainer`
2 parents 017c711 + d0b976f commit 82f4204

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

art/defences/adversarial_trainer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,16 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
7070
logged = False
7171
self._precomputed_adv_samples = []
7272
for attack in self.attacks:
73+
if 'targeted' in attack.attack_params:
74+
if attack.targeted:
75+
raise NotImplementedError("Adversarial training with targeted attacks is \
76+
currently not implemented")
77+
7378
if attack.classifier != self.classifier:
7479
if not logged:
7580
logger.info('Precomputing transferred adversarial samples.')
7681
logged = True
77-
self._precomputed_adv_samples.append(attack.generate(x))
82+
self._precomputed_adv_samples.append(attack.generate(x, y=y))
7883
else:
7984
self._precomputed_adv_samples.append(None)
8085

@@ -86,7 +91,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
8691

8792
for batch_id in range(nb_batches):
8893
# Create batch data
89-
x_batch = x[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]]
94+
x_batch = x[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]].copy()
9095
y_batch = y[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]]
9196

9297
# Choose indices to replace with adversarial samples
@@ -95,12 +100,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
95100

96101
# If source and target models are the same, craft fresh adversarial samples
97102
if attack.classifier == self.classifier:
98-
labels_batch = np.argmax(y_batch, axis=1)
99-
x_adv = attack.generate(x_batch[adv_ids])
100-
y_adv = np.argmax(attack.classifier.predict(x_adv), axis=1)
101-
selected = np.array(y_adv != labels_batch[adv_ids])
102-
103-
x_batch[adv_ids][selected] = x_adv[selected]
103+
x_batch[adv_ids] = attack.generate(x_batch[adv_ids], y=y_batch[adv_ids])
104104

105105
# Otherwise, use precomputed adversarial samples
106106
else:
@@ -153,9 +153,13 @@ def fit(self, x, y, **kwargs):
153153

154154
# Generate adversarial samples for each attack
155155
for i, attack in enumerate(self.attacks):
156+
if 'targeted' in attack.attack_params and attack.targeted:
157+
raise NotImplementedError("Adversarial training with targeted attacks is \
158+
currently not implemented")
159+
156160
logger.info('Generating adversarial samples from attack: %i/%i.', i, len(self.attacks))
157161
# Predict new labels for the adversarial samples generated
158-
x_adv = attack.generate(x)
162+
x_adv = attack.generate(x, y=y)
159163
y_pred = np.argmax(attack.classifier.predict(x_adv), axis=1)
160164
selected = np.array(labels != y_pred)
161165
logger.info('%i successful samples generated.', len(selected))

art/defences/adversarial_trainer_unittest.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,23 @@ def test_two_attacks(self):
169169

170170
logger.info('Accuracy before adversarial training: %.2f%%', (acc * 100))
171171
logger.info('\nAccuracy after adversarial training: %.2f%%', (acc_new * 100))
172+
173+
174+
def test_targeted_attack_error(self):
175+
"""
176+
Test the adversarial trainer using a targeted attack, which will currently result in a
177+
NotImplementError.
178+
179+
:return: None
180+
"""
181+
182+
(x_train, y_train), (x_test, y_test) = self.mnist
183+
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
184+
185+
classifier = self.classifier_k
186+
adv = FastGradientMethod(classifier, targeted=True)
187+
adv_trainer = AdversarialTrainer(classifier, attacks=adv)
188+
self.assertRaises(NotImplementedError, adv_trainer.fit, x_train, y_train, **params)
172189

173190

174191
class TestStaticAdversarialTrainer(TestBase):
@@ -270,6 +287,21 @@ def test_shared_model_mnist(self):
270287
logger.info('Accuracy before adversarial training: %.2f%%', (acc * 100))
271288
logger.info('Accuracy after adversarial training: %.2f%%', (acc_adv_trained * 100))
272289

290+
def test_targeted_attack_error(self):
291+
"""
292+
Test the adversarial trainer using a targeted attack, which will currently result in a
293+
NotImplementError.
294+
295+
:return: None
296+
"""
297+
298+
(x_train, y_train), (x_test, y_test) = self.mnist
299+
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
300+
301+
classifier = self.classifier_k
302+
adv = FastGradientMethod(classifier, targeted=True)
303+
adv_trainer = StaticAdversarialTrainer(classifier, attacks=adv)
304+
self.assertRaises(NotImplementedError, adv_trainer.fit, x_train, y_train, **params)
273305

274306
if __name__ == '__main__':
275307
unittest.main()

0 commit comments

Comments
 (0)