Skip to content

Commit eac78e3

Browse files
committed
catch targeted attacks
1 parent 78890a4 commit eac78e3

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

art/defences/adversarial_trainer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
6464
# Precompute adversarial samples for transferred attacks
6565
self._precomputed_adv_samples = []
6666
for attack in self.attacks:
67+
if 'targeted' in attack.attack_params:
68+
if attack.targeted:
69+
raise NotImplementedError("Adversarial training with targeted attacks is \
70+
currently not implemented")
71+
6772
if attack.classifier != self.classifier:
68-
self._precomputed_adv_samples.append(attack.generate(x))
73+
self._precomputed_adv_samples.append(attack.generate(x, y=y))
6974
else:
7075
self._precomputed_adv_samples.append(None)
7176

@@ -84,7 +89,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
8489

8590
# If source and target models are the same, craft fresh adversarial samples
8691
if attack.classifier == self.classifier:
87-
x_batch[adv_ids] = attack.generate(x_batch[adv_ids])
92+
x_batch[adv_ids] = attack.generate(x_batch[adv_ids], y=y_batch[adv_ids])
8893

8994
# Otherwise, use precomputed adversarial samples
9095
else:
@@ -137,8 +142,13 @@ def fit(self, x, y, **kwargs):
137142

138143
# Generate adversarial samples for each attack
139144
for attack in self.attacks:
145+
if 'targeted' in attack.attack_params:
146+
if attack.targeted:
147+
raise NotImplementedError("Adversarial training with targeted attacks is \
148+
currently not implemented")
149+
140150
# Predict new labels for the adversarial samples generated
141-
x_adv = attack.generate(x)
151+
x_adv = attack.generate(x, y=y)
142152
y_pred = np.argmax(attack.classifier.predict(x_adv), axis=1)
143153
selected = np.array(labels != y_pred)
144154

art/defences/adversarial_trainer_unittest.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,24 @@ def test_two_attacks(self):
165165

166166
print('\nAccuracy before adversarial training: %.2f%%' % (acc * 100))
167167
print('\nAccuracy after adversarial training: %.2f%%' % (acc_new * 100))
168+
169+
170+
def test_targeted_attack_error(self):
171+
"""
172+
Test the adversarial trainer using a targeted attack, which will currently result in a
173+
NotImplementedError.
174+
175+
:return: None
176+
"""
177+
178+
(x_train, y_train), (x_test, y_test) = self.mnist
179+
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
180+
181+
classifier = self.classifier_k
182+
adv = FastGradientMethod(classifier, targeted=True)
183+
adv_trainer = AdversarialTrainer(classifier, attacks=adv)
184+
self.assertRaises(NotImplementedError, adv_trainer.fit, x_train, y_train, **params)
185+
168186

169187

170188
class TestStaticAdversarialTrainer(TestBase):
@@ -266,6 +284,21 @@ def test_shared_model_mnist(self):
266284
print('\nAccuracy before adversarial training: %.2f%%' % (acc * 100))
267285
print('\nAccuracy after adversarial training: %.2f%%' % (acc_adv_trained * 100))
268286

287+
def test_targeted_attack_error(self):
288+
"""
289+
Test the adversarial trainer using a targeted attack, which will currently result in a
290+
NotImplementedError.
291+
292+
:return: None
293+
"""
294+
295+
(x_train, y_train), (x_test, y_test) = self.mnist
296+
params = {'nb_epochs': 2, 'batch_size': BATCH_SIZE}
297+
298+
classifier = self.classifier_k
299+
adv = FastGradientMethod(classifier, targeted=True)
300+
adv_trainer = StaticAdversarialTrainer(classifier, attacks=adv)
301+
self.assertRaises(NotImplementedError, adv_trainer.fit, x_train, y_train, **params)
269302

270303
if __name__ == '__main__':
271304
unittest.main()

0 commit comments

Comments
 (0)