Skip to content

Commit 3675a96

Browse files
committed
Updating AdversarialPatchPytorch to throw an error when user provides labels during an untargeted attack
1 parent bc54526 commit 3675a96

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def generate( # type: ignore
490490
491491
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
492492
:param y: True or target labels of format `list[dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each
493-
input image. The fields of the dict are as follows:
493+
input image. For untargeted attacks, this should be None. The fields of the dict are as follows:
494494
495495
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
496496
- labels [N]: the labels for each image.
@@ -515,6 +515,9 @@ def generate( # type: ignore
515515

516516
y_array: np.ndarray
517517

518+
if y is not None and not self.targeted:
519+
raise ValueError("The untargeted version of AdversarialPatch attack does not use target labels.")
520+
518521
if y is None: # pragma: no cover
519522
logger.info("Setting labels to estimator classification predictions.")
520523
y_array = to_categorical(
@@ -533,6 +536,9 @@ def generate( # type: ignore
533536
else:
534537
self.use_logits = True
535538
else:
539+
if y is not None and not self.targeted:
540+
raise ValueError("The untargeted version of AdversarialPatch attack does not use target labels.")
541+
536542
if y is None: # pragma: no cover
537543
logger.info("Setting labels to estimator object detection predictions.")
538544
y = self.estimator.predict(x=x)

0 commit comments

Comments
 (0)