@@ -78,7 +78,7 @@ def __init__(
78
78
batch_size : int = 32 ,
79
79
estimator_orig : "CLASSIFIER_TYPE" | None = None ,
80
80
targeted : bool = False ,
81
- parallel : bool = False ,
81
+ parallel_pool_size : int = 0 ,
82
82
):
83
83
"""
84
84
Create a :class:`.AutoAttack` instance.
@@ -93,7 +93,8 @@ def __init__(
93
93
:param estimator_orig: Original estimator to be attacked by adversarial examples.
94
94
:param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible
95
95
target.
96
- :param parallel: If True run attacks in parallel.
96
+ :param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0
97
+ computation runs without multiprocessing.
97
98
"""
98
99
super ().__init__ (estimator = estimator )
99
100
@@ -151,7 +152,7 @@ def __init__(
151
152
self .estimator_orig = estimator
152
153
153
154
self ._targeted = targeted
154
- self .parallel = parallel
155
+ self .parallel_pool_size = parallel_pool_size
155
156
self .best_attacks : np .ndarray = np .array ([])
156
157
self ._check_params ()
157
158
@@ -199,7 +200,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
199
200
if attack .targeted :
200
201
attack .set_params (targeted = False )
201
202
202
- if self .parallel :
203
+ if self .parallel_pool_size > 0 :
203
204
args .append (
204
205
(
205
206
deepcopy (x_adv ),
@@ -253,7 +254,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
253
254
targeted_labels [:, i ], nb_classes = self .estimator .nb_classes
254
255
)
255
256
256
- if self .parallel :
257
+ if self .parallel_pool_size > 0 :
257
258
args .append (
258
259
(
259
260
deepcopy (x_adv ),
@@ -287,8 +288,8 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
287
288
except ValueError as error :
288
289
logger .warning ("Error completing attack: %s}" , str (error ))
289
290
290
- if self .parallel :
291
- with multiprocess .get_context ("spawn" ).Pool () as pool :
291
+ if self .parallel_pool_size > 0 :
292
+ with multiprocess .get_context ("spawn" ).Pool (processes = self . parallel_pool_size ) as pool :
292
293
# Results come back in the order that they were issued
293
294
results = pool .starmap (run_attack , args )
294
295
perturbations = []
@@ -320,15 +321,16 @@ def __repr__(self) -> str:
320
321
This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks
321
322
per image passed to the AutoAttack class.
322
323
"""
323
- if self .parallel :
324
+ if self .parallel_pool_size > 0 :
324
325
best_attack_meta = "\n " .join (
325
326
[
326
327
f"image { i + 1 } : { str (self .args [idx ][3 ])} " if idx != 0 else f"image { i + 1 } : n/a"
327
328
for i , idx in enumerate (self .best_attacks )
328
329
]
329
330
)
330
331
auto_attack_meta = (
331
- f"AutoAttack(targeted={ self .targeted } , parallel={ self .parallel } , num_attacks={ len (self .args )} )"
332
+ f"AutoAttack(targeted={ self .targeted } , parallel_pool_size={ self .parallel_pool_size } , "
333
+ + "num_attacks={len(self.args)})"
332
334
)
333
335
return f"{ auto_attack_meta } \n BestAttacks:\n { best_attack_meta } "
334
336
@@ -339,7 +341,8 @@ def __repr__(self) -> str:
339
341
]
340
342
)
341
343
auto_attack_meta = (
342
- f"AutoAttack(targeted={ self .targeted } , parallel={ self .parallel } , num_attacks={ len (self .attacks )} )"
344
+ f"AutoAttack(targeted={ self .targeted } , parallel_pool_size={ self .parallel_pool_size } , "
345
+ + "num_attacks={len(self.attacks)})"
343
346
)
344
347
return f"{ auto_attack_meta } \n BestAttacks:\n { best_attack_meta } "
345
348
0 commit comments