@@ -70,11 +70,16 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
7070 logged = False
7171 self ._precomputed_adv_samples = []
7272 for attack in self .attacks :
73+ if 'targeted' in attack .attack_params :
74+ if attack .targeted :
75+ raise NotImplementedError ("Adversarial training with targeted attacks is \
76+ currently not implemented" )
77+
7378 if attack .classifier != self .classifier :
7479 if not logged :
7580 logger .info ('Precomputing transferred adversarial samples.' )
7681 logged = True
77- self ._precomputed_adv_samples .append (attack .generate (x ))
82+ self ._precomputed_adv_samples .append (attack .generate (x , y = y ))
7883 else :
7984 self ._precomputed_adv_samples .append (None )
8085
@@ -86,7 +91,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
8691
8792 for batch_id in range (nb_batches ):
8893 # Create batch data
89- x_batch = x [ind [batch_id * batch_size :min ((batch_id + 1 ) * batch_size , x .shape [0 ])]]
94+ x_batch = x [ind [batch_id * batch_size :min ((batch_id + 1 ) * batch_size , x .shape [0 ])]]. copy ()
9095 y_batch = y [ind [batch_id * batch_size :min ((batch_id + 1 ) * batch_size , x .shape [0 ])]]
9196
9297 # Choose indices to replace with adversarial samples
@@ -95,12 +100,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
95100
96101 # If source and target models are the same, craft fresh adversarial samples
97102 if attack .classifier == self .classifier :
98- labels_batch = np .argmax (y_batch , axis = 1 )
99- x_adv = attack .generate (x_batch [adv_ids ])
100- y_adv = np .argmax (attack .classifier .predict (x_adv ), axis = 1 )
101- selected = np .array (y_adv != labels_batch [adv_ids ])
102-
103- x_batch [adv_ids ][selected ] = x_adv [selected ]
103+ x_batch [adv_ids ] = attack .generate (x_batch [adv_ids ], y = y_batch [adv_ids ])
104104
105105 # Otherwise, use precomputed adversarial samples
106106 else :
@@ -153,9 +153,13 @@ def fit(self, x, y, **kwargs):
153153
154154 # Generate adversarial samples for each attack
155155 for i , attack in enumerate (self .attacks ):
156+ if 'targeted' in attack .attack_params and attack .targeted :
157+ raise NotImplementedError ("Adversarial training with targeted attacks is \
158+ currently not implemented" )
159+
156160 logger .info ('Generating adversarial samples from attack: %i/%i.' , i , len (self .attacks ))
157161 # Predict new labels for the adversarial samples generated
158- x_adv = attack .generate (x )
162+ x_adv = attack .generate (x , y = y )
159163 y_pred = np .argmax (attack .classifier .predict (x_adv ), axis = 1 )
160164 selected = np .array (labels != y_pred )
161165 logger .info ('%i successful samples generated.' , len (selected ))
0 commit comments