2727from typing import Optional , Tuple , TYPE_CHECKING
2828
2929import numpy as np
30- from tqdm import tqdm
30+ from tqdm import tqdm , trange
3131
3232from art .attacks .attack import EvasionAttack
3333from art .config import ART_NUMPY_DTYPE
@@ -75,6 +75,7 @@ def __init__(
7575 num_trial : int = 25 ,
7676 sample_size : int = 20 ,
7777 init_size : int = 100 ,
78+ min_epsilon : Optional [float ] = None ,
7879 verbose : bool = True ,
7980 ) -> None :
8081 """
@@ -89,6 +90,7 @@ def __init__(
8990 :param num_trial: Maximum number of trials per iteration.
9091 :param sample_size: Number of samples per trial.
9192 :param init_size: Maximum number of trials for initial generation of adversarial examples.
93+ :param min_epsilon: Stop attack if perturbation is smaller than `min_epsilon`.
9294 :param verbose: Show progress bars.
9395 """
9496 super ().__init__ (estimator = estimator )
@@ -101,10 +103,13 @@ def __init__(
101103 self .num_trial = num_trial
102104 self .sample_size = sample_size
103105 self .init_size = init_size
106+ self .min_epsilon = min_epsilon
104107 self .batch_size = 1
105108 self .verbose = verbose
106109 self ._check_params ()
107110
111+ self .curr_adv = None
112+
108113 def generate (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
109114 """
110115 Generate adversarial samples and return them in an array.
@@ -230,8 +235,10 @@ def _attack(
230235 self .curr_delta = initial_delta
231236 self .curr_epsilon = initial_epsilon
232237
238+ self .curr_adv = x_adv
239+
233240 # Main loop to wander around the boundary
234- for _ in range (self .max_iter ):
241+ for _ in trange (self .max_iter , desc = "Boundary attack - iterations" , disable = not self . verbose ):
235242 # Trust region method to adjust delta
236243 for _ in range (self .num_trial ):
237244 potential_advs = []
@@ -273,11 +280,15 @@ def _attack(
273280
274281 if epsilon_ratio > 0 :
275282 x_adv = potential_advs [np .where (satisfied )[0 ][0 ]]
283+ self .curr_adv = x_adv
276284 break
277285 else :
278286 logger .warning ("Adversarial example found but not optimal." )
279287 return x_advs [0 ]
280288
289+ if self .min_epsilon is not None and self .curr_epsilon < self .min_epsilon :
290+ return x_adv
291+
281292 return x_adv
282293
283294 def _orthogonal_perturb (self , delta : float , current_sample : np .ndarray , original_sample : np .ndarray ) -> np .ndarray :
@@ -299,20 +310,13 @@ def _orthogonal_perturb(self, delta: float, current_sample: np.ndarray, original
299310 # Project the perturbation onto sphere
300311 direction = original_sample - current_sample
301312
302- if len (self .estimator .input_shape ) == 3 :
303- channel_index = 1 if self .estimator .channels_first else 3
304- perturb = np .swapaxes (perturb , 0 , channel_index - 1 )
305- direction = np .swapaxes (direction , 0 , channel_index - 1 )
306- for i in range (direction .shape [0 ]):
307- direction [i ] /= np .linalg .norm (direction [i ])
308- perturb [i ] -= np .dot (np .dot (perturb [i ], direction [i ].T ), direction [i ])
309-
310- perturb = np .swapaxes (perturb , 0 , channel_index - 1 )
311- elif len (self .estimator .input_shape ) == 1 :
312- direction /= np .linalg .norm (direction )
313- perturb -= np .dot (perturb , direction .T ) * direction
314- else :
315- raise ValueError ("Input shape not recognised." )
313+ direction_flat = direction .flatten ()
314+ perturb_flat = perturb .flatten ()
315+
316+ direction_flat /= np .linalg .norm (direction_flat )
317+ perturb_flat -= np .dot (perturb_flat , direction_flat .T ) * direction_flat
318+ perturb = perturb_flat .reshape (self .estimator .input_shape )
319+
316320 hypotenuse = np .sqrt (1 + delta ** 2 )
317321 perturb = ((1 - hypotenuse ) * (current_sample - original_sample ) + perturb ) / hypotenuse
318322 return perturb
@@ -403,5 +407,8 @@ def _check_params(self) -> None:
403407 if self .step_adapt <= 0 or self .step_adapt >= 1 :
404408 raise ValueError ("The adaptation factor must be in the range (0, 1)." )
405409
410+ if self .min_epsilon is not None and (isinstance (self .min_epsilon , float ) or self .min_epsilon <= 0 ):
411+ raise ValueError ("The minimum epsilon must be a positive float." )
412+
406413 if not isinstance (self .verbose , bool ):
407414 raise ValueError ("The argument `verbose` has to be of type bool." )
0 commit comments