@@ -81,6 +81,7 @@ def generate(self, x, **kwargs):
8181 # Instantiate the middle attacker and get the predicted labels
8282 attacker = self ._get_attack (self .attacker , self .attacker_params )
8383 pred_y = self .classifier .predict (x , logits = False )
84+ pred_y_max = np .argmax (pred_y , axis = 1 )
8485
8586 # Start to generate the adversarial examples
8687 nb_iter = 0
@@ -92,14 +93,14 @@ def generate(self, x, **kwargs):
9293 for j , ex in enumerate (x [rnd_idx ]):
9394 xi = ex [None , ...]
9495
95- f_xi = self .classifier .predict (xi + v , logits = False )
96+ f_xi = self .classifier .predict (xi + v , logits = True )
9697 fk_i_hat = np .argmax (f_xi [0 ])
9798 fk_hat = np .argmax (pred_y [rnd_idx ][j ])
9899
99100 if fk_i_hat == fk_hat :
100101 # Compute adversarial perturbation
101102 adv_xi = attacker .generate (xi + v )
102- adv_f_xi = self .classifier .predict (adv_xi , logits = False )
103+ adv_f_xi = self .classifier .predict (adv_xi , logits = True )
103104 adv_fk_i_hat = np .argmax (adv_f_xi [0 ])
104105
105106 # If the class has changed, update v
@@ -112,10 +113,8 @@ def generate(self, x, **kwargs):
112113
113114 # Compute the error rate
114115 adv_x = x + v
115- adv_y = self .classifier .predict (adv_x , logits = False )
116- adv_y_max = np .argmax (adv_y , axis = 1 )
117- pred_y_max = np .argmax (pred_y , axis = 1 )
118- fooling_rate = np .sum (pred_y_max != adv_y_max ) / float (nb_instances )
116+ adv_y = np .argmax (self .classifier .predict (adv_x , logits = False ))
117+ fooling_rate = np .sum (pred_y_max != adv_y ) / nb_instances
119118
120119 self .fooling_rate = fooling_rate
121120 self .converged = (nb_iter < self .max_iter )
@@ -213,4 +212,3 @@ def _get_class(self, class_name):
213212 class_module = getattr (module , sub_mods [- 1 ])
214213
215214 return class_module
216-
0 commit comments