Skip to content

Commit 8bfb305

Browse files
authored
Merge pull request #2641 from stekunda/issue-547-api-inconsistency
Update AdversarialPatchPytorch untargeted attacks to throw a Value Error when the user provides labels
2 parents ce7d2d9 + 79efb61 commit 8bfb305

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,8 @@ def generate( # type: ignore
490490
491491
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
492492
:param y: True or target labels of format `list[dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each
493-
input image. The fields of the dict are as follows:
493+
input image. For untargeted attacks, if `y` is None the attack will attack the model predictions
494+
on `x`. The fields of the dict are as follows:
494495
495496
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
496497
- labels [N]: the labels for each image.
@@ -511,6 +512,9 @@ def generate( # type: ignore
511512
if self.patch_location is not None and mask is not None:
512513
raise ValueError("Masks can only be used if the `patch_location` is `None`.")
513514

515+
if y is None and self.targeted:
516+
raise ValueError("The targeted version of AdversarialPatch attack requires provided target labels.")
517+
514518
if hasattr(self.estimator, "nb_classes"):
515519

516520
y_array: np.ndarray

art/attacks/evasion/dpatch_robust.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def generate( # type: ignore
155155
if y is None and self.targeted:
156156
raise ValueError("The targeted version of RobustDPatch attack requires target labels provided to `y`.")
157157
if y is not None and not self.targeted:
158-
raise ValueError("The RobustDPatch attack does not use target labels.")
158+
raise ValueError("The untargeted version of RobustDPatch attack does not use True labels provided to 'y'.")
159159
if x.ndim != 4: # pragma: no cover
160160
raise ValueError("The adversarial patch can only be applied to images.")
161161

0 commit comments

Comments
 (0)