@@ -107,7 +107,12 @@ def generate( # pylint: disable=W0221
107107 Generate DPatch.
108108
109109 :param x: Sample images.
110- :param y: Target labels for object detector.
110+ :param y: True labels of type `List[Dict[np.ndarray]]` for untargeted attack, one dictionary per input image.
111+ The keys and values of the dictionary are:
112+
113+ - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
114+ - labels [N]: the labels for each image
115+ - scores [N]: the scores or each prediction.
111116 :param target_label: The target label of the DPatch attack.
112117 :param mask: An boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
113118 (N, H, W) without their channel dimensions. Any features for which the mask is True can be the
@@ -134,8 +139,6 @@ def generate( # pylint: disable=W0221
134139 channel_index = 1 if self .estimator .channels_first else x .ndim - 1
135140 if x .shape [channel_index ] != self .patch_shape [channel_index - 1 ]:
136141 raise ValueError ("The color channel index of the images and the patch have to be identical." )
137- if y is not None :
138- raise ValueError ("The DPatch attack does not use target labels." )
139142 if x .ndim != 4 : # pragma: no cover
140143 raise ValueError ("The adversarial patch can only be applied to images." )
141144 if target_label is not None :
@@ -160,7 +163,7 @@ def generate( # pylint: disable=W0221
160163 )
161164 patch_target : List [Dict [str , np .ndarray ]] = []
162165
163- if self .target_label :
166+ if self .target_label and y is None :
164167
165168 for i_image in range (patched_images .shape [0 ]):
166169 if isinstance (self .target_label , int ):
@@ -190,7 +193,7 @@ def generate( # pylint: disable=W0221
190193
191194 else :
192195
193- predictions = self .estimator .predict (x = patched_images , standardise_output = True )
196+ predictions = y if y is not None else self .estimator .predict (x = patched_images , standardise_output = True )
194197
195198 for i_image in range (patched_images .shape [0 ]):
196199 target_dict = {}
0 commit comments