Skip to content

Commit 2afa66a

Browse files
committed
Update typing and documentation of AdversarialPatchPyTorch for object detection
Signed-off-by: Beat Buesser <[email protected]>
1 parent 3f2dbef commit 2afa66a

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@ def __init__(
183183
self._optimizer = torch.optim.Adam([self._patch], lr=self.learning_rate)
184184

185185
def _train_step(
186-
self, images: "torch.Tensor", target: "torch.Tensor", mask: "torch.Tensor" | None = None
186+
self,
187+
images: "torch.Tensor",
188+
target: "torch.Tensor" | list[dict[str, "torch.Tensor"]],
189+
mask: "torch.Tensor" | None = None,
187190
) -> "torch.Tensor":
188191
import torch
189192

@@ -227,7 +230,12 @@ def _predictions(
227230

228231
return predictions, target
229232

230-
def _loss(self, images: "torch.Tensor", target: "torch.Tensor", mask: "torch.Tensor" | None) -> "torch.Tensor":
233+
def _loss(
234+
self,
235+
images: "torch.Tensor",
236+
target: "torch.Tensor" | list[dict[str, "torch.Tensor"]],
237+
mask: "torch.Tensor" | None,
238+
) -> "torch.Tensor":
231239
import torch
232240

233241
if isinstance(target, torch.Tensor):
@@ -475,13 +483,13 @@ def _random_overlay(
475483
return patched_images
476484

477485
def generate( # type: ignore
478-
self, x: np.ndarray, y: np.ndarray | None = None, **kwargs
486+
self, x: np.ndarray, y: np.ndarray | list[dict[str, np.ndarray | "torch.Tensor"]] | None = None, **kwargs
479487
) -> tuple[np.ndarray, np.ndarray]:
480488
"""
481489
Generate an adversarial patch and return the patch and its mask in arrays.
482490
483491
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
484-
:param y: An array with the original true labels.
492+
:param y: The true or target labels.
485493
:param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
486494
(N, H, W) without their channel dimensions. Any features for which the mask is True can be the
487495
center location of the patch during sampling.
@@ -499,11 +507,12 @@ def generate( # type: ignore
499507
if self.patch_location is not None and mask is not None:
500508
raise ValueError("Masks can only be used if the `patch_location` is `None`.")
501509

502-
if y is None: # pragma: no cover
503-
logger.info("Setting labels to estimator predictions and running untargeted attack because `y=None`.")
504-
y = to_categorical(np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes)
505-
506510
if hasattr(self.estimator, "nb_classes"):
511+
512+
if y is None: # pragma: no cover
513+
logger.info("Setting labels to estimator classification predictions.")
514+
y = to_categorical(np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes)
515+
507516
y = check_and_transform_label_format(labels=y, nb_classes=self.estimator.nb_classes)
508517

509518
# check if logits or probabilities
@@ -513,6 +522,10 @@ def generate( # type: ignore
513522
self.use_logits = False
514523
else:
515524
self.use_logits = True
525+
else:
526+
if y is None: # pragma: no cover
527+
logger.info("Setting labels to estimator object detection predictions.")
528+
y = self.estimator.predict(x=x)
516529

517530
if isinstance(y, np.ndarray):
518531
x_tensor = torch.Tensor(x)

0 commit comments

Comments
 (0)