@@ -489,7 +489,11 @@ def generate( # type: ignore
489
489
Generate an adversarial patch and return the patch and its mask in arrays.
490
490
491
491
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
492
- :param y: The true or target labels.
492
+ :param y: True or target labels of format `list[dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each
493
+ input image. The fields of the dict are as follows:
494
+
495
+ - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
496
+ - labels [N]: the class label of each of the N boxes in the image.
493
497
:param mask: A boolean array of shape equal to the shape of a single sample (1, H, W) or the shape of `x`
494
498
(N, H, W) without their channel dimensions. Any features for which the mask is True can be the
495
499
center location of the patch during sampling.
@@ -511,11 +515,13 @@ def generate( # type: ignore
511
515
512
516
if y is None : # pragma: no cover
513
517
logger .info ("Setting labels to estimator classification predictions." )
514
- y = to_categorical (np .argmax (self .estimator .predict (x = x ), axis = 1 ), nb_classes = self .estimator .nb_classes )
515
-
518
+ y_array : np .ndarray = to_categorical (
519
+ np .argmax (self .estimator .predict (x = x ), axis = 1 ), nb_classes = self .estimator .nb_classes
520
+ )
521
+ else :
516
522
y_array : np .ndarray = y
517
523
518
- y = check_and_transform_label_format (labels = y_array , nb_classes = self .estimator .nb_classes )
524
+ y = check_and_transform_label_format (labels = y_array , nb_classes = self .estimator .nb_classes )
519
525
520
526
# check if logits or probabilities
521
527
y_pred = self .estimator .predict (x = x [[0 ]])
0 commit comments