2323import logging
2424import math
2525import random
26- from typing import Dict , List , Optional , Tuple , TYPE_CHECKING
26+ from typing import Dict , List , Optional , Tuple , Union , TYPE_CHECKING
2727
2828import numpy as np
2929from tqdm import trange
@@ -81,12 +81,21 @@ def __init__(
8181 self ._patch = np .ones (shape = patch_shape ) * (self .estimator .clip_values [1 ] + self .estimator .clip_values [0 ]) / 2.0
8282 self ._check_params ()
8383
84- def generate (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
84+ self .target_label = []
85+
86+ def generate (
87+ self ,
88+ x : np .ndarray ,
89+ y : Optional [np .ndarray ] = None ,
90+ target_label : Union [int , List [int ], np .ndarray ] = 0 ,
91+ ** kwargs
92+ ) -> np .ndarray :
8593 """
8694 Generate DPatch.
8795
8896 :param x: Sample images.
8997 :param y: Target labels for object detector.
98+ :param target_label: The target label of the DPatch attack.
9099 :return: Adversarial patch.
91100 """
92101 channel_index = 1 if self .estimator .channels_first else x .ndim - 1
@@ -96,6 +105,16 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
96105 raise ValueError ("The DPatch attack does not use target labels." )
97106 if x .ndim != 4 :
98107 raise ValueError ("The adversarial patch can only be applied to images." )
108+ if isinstance (target_label , int ):
109+ self .target_label = [target_label ] * x .shape [0 ]
110+ elif isinstance (target_label , np .ndarray ):
111+ if not (target_label .shape == (x .shape [0 ], 1 ) or target_label .shape == (x .shape [0 ],)):
112+ raise ValueError ("The target_label has to be a 1-dimensional array." )
113+ self .target_label = target_label .tolist ()
114+ else :
115+ if not len (target_label == x .shape [0 ]) or not isinstance (target_label , list ):
116+ raise ValueError ("The target_label as list of integers needs to of length number of images in `x`." )
117+ self .target_label = target_label
99118
100119 for i_step in trange (self .max_iter , desc = "DPatch iteration" ):
101120 if i_step == 0 or (i_step + 1 ) % 100 == 0 :
@@ -115,7 +134,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
115134
116135 target_dict = dict ()
117136 target_dict ["boxes" ] = np .asarray ([[i_x_1 , i_y_1 , i_x_2 , i_y_2 ]])
118- target_dict ["labels" ] = np .asarray ([1 ,])
137+ target_dict ["labels" ] = np .asarray ([target_label [ i_image ] ,])
119138 target_dict ["scores" ] = np .asarray ([1.0 ,])
120139
121140 patch_target .append (target_dict )
0 commit comments