@@ -78,7 +78,7 @@ def __init__(
7878 batch_size : int = 32 ,
7979 estimator_orig : "CLASSIFIER_TYPE" | None = None ,
8080 targeted : bool = False ,
81- parallel : bool = False ,
81+ parallel_pool_size : int = 0 ,
8282 ):
8383 """
8484 Create a :class:`.AutoAttack` instance.
@@ -93,7 +93,8 @@ def __init__(
9393 :param estimator_orig: Original estimator to be attacked by adversarial examples.
9494 :param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible
9595 target.
96- :param parallel: If True run attacks in parallel.
96+ :param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0
97+ computation runs without multiprocessing.
9798 """
9899 super ().__init__ (estimator = estimator )
99100
@@ -151,7 +152,7 @@ def __init__(
151152 self .estimator_orig = estimator
152153
153154 self ._targeted = targeted
154- self .parallel = parallel
155+ self .parallel_pool_size = parallel_pool_size
155156 self .best_attacks : np .ndarray = np .array ([])
156157 self ._check_params ()
157158
@@ -199,7 +200,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
199200 if attack .targeted :
200201 attack .set_params (targeted = False )
201202
202- if self .parallel :
203+ if self .parallel_pool_size > 0 :
203204 args .append (
204205 (
205206 deepcopy (x_adv ),
@@ -253,7 +254,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
253254 targeted_labels [:, i ], nb_classes = self .estimator .nb_classes
254255 )
255256
256- if self .parallel :
257+ if self .parallel_pool_size > 0 :
257258 args .append (
258259 (
259260 deepcopy (x_adv ),
@@ -287,8 +288,8 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
287288 except ValueError as error :
288289 logger .warning ("Error completing attack: %s}" , str (error ))
289290
290- if self .parallel :
291- with multiprocess .get_context ("spawn" ).Pool () as pool :
291+ if self .parallel_pool_size > 0 :
292+ with multiprocess .get_context ("spawn" ).Pool (processes = self . parallel_pool_size ) as pool :
292293 # Results come back in the order that they were issued
293294 results = pool .starmap (run_attack , args )
294295 perturbations = []
@@ -320,15 +321,16 @@ def __repr__(self) -> str:
320321 This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks
321322 per image passed to the AutoAttack class.
322323 """
323- if self .parallel :
324+ if self .parallel_pool_size > 0 :
324325 best_attack_meta = "\n " .join (
325326 [
326327 f"image { i + 1 } : { str (self .args [idx ][3 ])} " if idx != 0 else f"image { i + 1 } : n/a"
327328 for i , idx in enumerate (self .best_attacks )
328329 ]
329330 )
330331 auto_attack_meta = (
331- f"AutoAttack(targeted={ self .targeted } , parallel={ self .parallel } , num_attacks={ len (self .args )} )"
332+ f"AutoAttack(targeted={ self .targeted } , parallel_pool_size={ self .parallel_pool_size } , "
333+ + "num_attacks={len(self.args)})"
332334 )
333335 return f"{ auto_attack_meta } \n BestAttacks:\n { best_attack_meta } "
334336
@@ -339,7 +341,8 @@ def __repr__(self) -> str:
339341 ]
340342 )
341343 auto_attack_meta = (
342- f"AutoAttack(targeted={ self .targeted } , parallel={ self .parallel } , num_attacks={ len (self .attacks )} )"
344+ f"AutoAttack(targeted={ self .targeted } , parallel_pool_size={ self .parallel_pool_size } , "
345+ + "num_attacks={len(self.attacks)})"
343346 )
344347 return f"{ auto_attack_meta } \n BestAttacks:\n { best_attack_meta } "
345348
0 commit comments