Skip to content

Commit bf196eb

Browse files
authored
Merge pull request #2534 from Trusted-AI/development_issue_2529
Add option for pool size in AutoAttack
2 parents 21f1923 + 7903726 commit bf196eb

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

art/attacks/evasion/auto_attack.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
batch_size: int = 32,
7979
estimator_orig: "CLASSIFIER_TYPE" | None = None,
8080
targeted: bool = False,
81-
parallel: bool = False,
81+
parallel_pool_size: int = 0,
8282
):
8383
"""
8484
Create a :class:`.AutoAttack` instance.
@@ -93,7 +93,8 @@ def __init__(
9393
:param estimator_orig: Original estimator to be attacked by adversarial examples.
9494
:param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible
9595
target.
96-
:param parallel: If True run attacks in parallel.
96+
:param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0
97+
computation runs without multiprocessing.
9798
"""
9899
super().__init__(estimator=estimator)
99100

@@ -151,7 +152,7 @@ def __init__(
151152
self.estimator_orig = estimator
152153

153154
self._targeted = targeted
154-
self.parallel = parallel
155+
self.parallel_pool_size = parallel_pool_size
155156
self.best_attacks: np.ndarray = np.array([])
156157
self._check_params()
157158

@@ -199,7 +200,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
199200
if attack.targeted:
200201
attack.set_params(targeted=False)
201202

202-
if self.parallel:
203+
if self.parallel_pool_size > 0:
203204
args.append(
204205
(
205206
deepcopy(x_adv),
@@ -253,7 +254,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
253254
targeted_labels[:, i], nb_classes=self.estimator.nb_classes
254255
)
255256

256-
if self.parallel:
257+
if self.parallel_pool_size > 0:
257258
args.append(
258259
(
259260
deepcopy(x_adv),
@@ -287,8 +288,8 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
287288
except ValueError as error:
288289
logger.warning("Error completing attack: %s}", str(error))
289290

290-
if self.parallel:
291-
with multiprocess.get_context("spawn").Pool() as pool:
291+
if self.parallel_pool_size > 0:
292+
with multiprocess.get_context("spawn").Pool(processes=self.parallel_pool_size) as pool:
292293
# Results come back in the order that they were issued
293294
results = pool.starmap(run_attack, args)
294295
perturbations = []
@@ -320,15 +321,16 @@ def __repr__(self) -> str:
320321
This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks
321322
per image passed to the AutoAttack class.
322323
"""
323-
if self.parallel:
324+
if self.parallel_pool_size > 0:
324325
best_attack_meta = "\n".join(
325326
[
326327
f"image {i+1}: {str(self.args[idx][3])}" if idx != 0 else f"image {i+1}: n/a"
327328
for i, idx in enumerate(self.best_attacks)
328329
]
329330
)
330331
auto_attack_meta = (
331-
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.args)})"
332+
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
333+
+ "num_attacks={len(self.args)})"
332334
)
333335
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"
334336

@@ -339,7 +341,8 @@ def __repr__(self) -> str:
339341
]
340342
)
341343
auto_attack_meta = (
342-
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.attacks)})"
344+
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
345+
+ "num_attacks={len(self.attacks)})"
343346
)
344347
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"
345348

tests/attacks/evasion/test_auto_attack.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
273273
batch_size=batch_size,
274274
estimator_orig=None,
275275
targeted=False,
276-
parallel=True,
276+
parallel_pool_size=3,
277277
)
278278

279279
attack_noparallel = AutoAttack(
@@ -285,7 +285,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
285285
batch_size=batch_size,
286286
estimator_orig=None,
287287
targeted=False,
288-
parallel=False,
288+
parallel_pool_size=0,
289289
)
290290

291291
x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
@@ -310,7 +310,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
310310
batch_size=batch_size,
311311
estimator_orig=None,
312312
targeted=True,
313-
parallel=True,
313+
parallel_pool_size=3,
314314
)
315315

316316
x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)

0 commit comments

Comments
 (0)