Skip to content

Commit 3378220

Browse files
committed
Update docstring and fix order of function calls
Signed-off-by: Beat Buesser <[email protected]>
1 parent 911884b commit 3378220

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,11 @@ def generate( # type: ignore
489489
Generate an adversarial patch and return the patch and its mask in arrays.
490490
491491
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
492-
:param y: The true or target labels.
492+
:param y: True or target labels of format `list[dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each
493+
input image. The fields of the dict are as follows:
494+
495+
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
496+
- labels [N]: the labels for each image.
493497
:param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
494498
(N, H, W) without their channel dimensions. Any features for which the mask is True can be the
495499
center location of the patch during sampling.
@@ -511,11 +515,13 @@ def generate( # type: ignore
511515

512516
if y is None: # pragma: no cover
513517
logger.info("Setting labels to estimator classification predictions.")
514-
y = to_categorical(np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes)
515-
518+
y_array: np.ndarray = to_categorical(
519+
np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes
520+
)
521+
else:
516522
y_array: np.ndarray = y
517523

518-
y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)
524+
y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)
519525

520526
# check if logits or probabilities
521527
y_pred = self.estimator.predict(x=x[[0]])

0 commit comments

Comments
 (0)