@@ -64,8 +64,13 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
6464 # Precompute adversarial samples for transferred attacks
6565 self ._precomputed_adv_samples = []
6666 for attack in self .attacks :
67+ if 'targeted' in attack .attack_params :
68+ if attack .targeted :
69+ raise NotImplementedError ("Adversarial training with targeted attacks is \
70+ currently not implemented" )
71+
6772 if attack .classifier != self .classifier :
68- self ._precomputed_adv_samples .append (attack .generate (x ))
73+ self ._precomputed_adv_samples .append (attack .generate (x , y = y ))
6974 else :
7075 self ._precomputed_adv_samples .append (None )
7176
@@ -84,7 +89,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
8489
8590 # If source and target models are the same, craft fresh adversarial samples
8691 if attack .classifier == self .classifier :
87- x_batch [adv_ids ] = attack .generate (x_batch [adv_ids ])
92+ x_batch [adv_ids ] = attack .generate (x_batch [adv_ids ], y = y_batch [ adv_ids ] )
8893
8994 # Otherwise, use precomputed adversarial samples
9095 else :
@@ -137,8 +142,13 @@ def fit(self, x, y, **kwargs):
137142
138143 # Generate adversarial samples for each attack
139144 for attack in self .attacks :
145+ if 'targeted' in attack .attack_params :
146+ if attack .targeted :
147+ raise NotImplementedError ("Adversarial training with targeted attacks is \
148+ currently not implemented" )
149+
140150 # Predict new labels for the adversarial samples generated
141- x_adv = attack .generate (x )
151+ x_adv = attack .generate (x , y = y )
142152 y_pred = np .argmax (attack .classifier .predict (x_adv ), axis = 1 )
143153 selected = np .array (labels != y_pred )
144154
0 commit comments