@@ -111,57 +111,68 @@ def __init__(
111111 "Unexpected input_shape in estimator detected. AdversarialPatch is expecting images or videos as input."
112112 )
113113
114- self .image_shape = self .estimator .input_shape
114+ self .input_shape = self .estimator .input_shape
115115
116- self .i_h_patch = 0
117- self .i_w_patch = 1
118-
119- self .nb_dims = len (self .image_shape )
116+ self .nb_dims = len (self .input_shape )
120117 if self .nb_dims == 3 :
121118 if self .estimator .channels_first :
119+ self .i_c = 0
122120 self .i_h = 1
123121 self .i_w = 2
124122 else :
125123 self .i_h = 0
126124 self .i_w = 1
125+ self .i_c = 2
127126 elif self .nb_dims == 4 :
128127 if self .estimator .channels_first :
128+ self .i_c = 1
129129 self .i_h = 2
130130 self .i_w = 3
131131 else :
132132 self .i_h = 1
133133 self .i_w = 2
134+ self .i_c = 3
135+
136+ smallest_image_edge = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
137+ nb_channels = self .input_shape [self .i_c ]
134138
135139 if self .estimator .channels_first :
136- smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
137- nb_channels = self .image_shape [0 ]
138140 self .patch_shape = (nb_channels , smallest_image_edge , smallest_image_edge )
139141 else :
140- smallest_image_edge = np .minimum (self .image_shape [0 ], self .image_shape [1 ])
141- nb_channels = self .image_shape [2 ]
142142 self .patch_shape = (smallest_image_edge , smallest_image_edge , nb_channels )
143143
144- self .patch_shape = self .image_shape
145-
146- mean_value = (self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]) / 2.0 + self .estimator .clip_values [
147- 0
148- ]
149- self .patch = np .ones (shape = self .patch_shape ).astype (np .float32 ) * mean_value
144+ self .patch = None
145+ self .mean_value = (
146+ self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]
147+ ) / 2.0 + self .estimator .clip_values [0 ]
148+ self .reset_patch (self .mean_value )
150149
151150 def generate (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> Tuple [np .ndarray , np .ndarray ]:
152151 """
153152 Generate an adversarial patch and return the patch and its mask in arrays.
154153
155154 :param x: An array with the original input images of shape NHWC or NCHW or input videos of shape NFHWC or NFCHW.
156155 :param y: An array with the original true labels.
157- :param mask: An boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
156+ :param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
158157 (N, H, W) without their channel dimensions. Any features for which the mask is True can be the
159158 center location of the patch during sampling.
160159 :type mask: `np.ndarray`
160+ :param reset_patch: If `True` reset patch to initial values of mean of minimal and maximal clip value, else if
161+ `False` (default) restart from previous patch values created by previous call to `generate`
162+ or mean of minimal and maximal clip value if first call to `generate`.
163+ :type reset_patch: bool
161164 :return: An array with adversarial patch and an array of the patch mask.
162165 """
163166 logger .info ("Creating adversarial patch." )
164167
168+ test_input_shape = list (self .estimator .input_shape )
169+
170+ for i , size in enumerate (self .estimator .input_shape ):
171+ if size is None or size != x .shape [i + 1 ]:
172+ test_input_shape [i ] = x .shape [i + 1 ]
173+
174+ self .input_shape = tuple (test_input_shape )
175+
165176 mask = kwargs .get ("mask" )
166177 if mask is not None :
167178 mask = mask .copy ()
@@ -184,6 +195,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
184195 "dimensions."
185196 )
186197
198+ if kwargs .get ("reset_patch" ):
199+ self ._reset_patch ()
200+
187201 y_target = check_and_transform_label_format (labels = y , nb_classes = self .estimator .nb_classes )
188202
189203 for _ in trange (self .max_iter , desc = "Adversarial Patch Numpy" , disable = not self .verbose ):
@@ -206,6 +220,8 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
206220 patch_gradients_i = self ._reverse_transformation (
207221 gradients [i_image , :, :, :], patch_mask_transformed [i_image , :, :, :], transforms [i_image ],
208222 )
223+ if self .nb_dims == 4 :
224+ patch_gradients_i = np .mean (patch_gradients_i , axis = 0 )
209225 patch_gradients += patch_gradients_i
210226
211227 # patch_gradients = patch_gradients / (num_batches * self.batch_size)
@@ -274,7 +290,7 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
274290 """
275291 Return a circular patch mask
276292 """
277- diameter = np .minimum (self .patch_shape [self .i_h ], self .patch_shape [self .i_w ])
293+ diameter = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
278294
279295 x = np .linspace (- 1 , 1 , diameter )
280296 y = np .linspace (- 1 , 1 , diameter )
@@ -286,26 +302,12 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
286302 channel_index = 1 if self .estimator .channels_first else 3
287303 axis = channel_index - 1
288304 mask = np .expand_dims (mask , axis = axis )
289- mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
290-
291- pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
292- pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
293-
294- pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
295- pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
296305
297- if self .estimator .channels_first :
298- if self .nb_dims == 3 :
299- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
300- elif self .nb_dims == 4 :
301- pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
302- else :
303- if self .nb_dims == 3 :
304- pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
305- elif self .nb_dims == 4 :
306- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
306+ mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
307307
308- mask = np .pad (mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
308+ if self .nb_dims == 4 :
309+ mask = np .expand_dims (mask , axis = 0 )
310+ mask = np .repeat (mask , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
309311
310312 return mask
311313
@@ -353,22 +355,18 @@ def _rotate(self, x, angle):
353355
354356 def _scale (self , x , scale ):
355357 zooms = None
356- height = None
357- width = None
358+ height , width = x . shape [ self . i_h ], x . shape [ self . i_w ]
359+
358360 if self .estimator .channels_first :
359361 if self .nb_dims == 3 :
360362 zooms = (1.0 , scale , scale )
361- height , width = self .patch_shape [1 :3 ]
362363 elif self .nb_dims == 4 :
363364 zooms = (1.0 , 1.0 , scale , scale )
364- height , width = self .patch_shape [2 :4 ]
365365 elif not self .estimator .channels_first :
366366 if self .nb_dims == 3 :
367367 zooms = (scale , scale , 1.0 )
368- height , width = self .patch_shape [0 :2 ]
369368 elif self .nb_dims == 4 :
370369 zooms = (1.0 , scale , scale , 1.0 )
371- height , width = self .patch_shape [1 :3 ]
372370
373371 if scale < 1.0 :
374372 scale_h = int (np .round (height * scale ))
@@ -449,6 +447,10 @@ def _random_transformation(self, patch, scale, mask_2d):
449447 patch_mask = self ._get_circular_patch_mask ()
450448 transformation = dict ()
451449
450+ if self .nb_dims == 4 :
451+ patch = np .expand_dims (patch , axis = 0 )
452+ patch = np .repeat (patch , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
453+
452454 # rotate
453455 angle = random .uniform (- self .rotation_max , self .rotation_max )
454456 transformation ["rotate" ] = angle
@@ -462,10 +464,34 @@ def _random_transformation(self, patch, scale, mask_2d):
462464 patch_mask = self ._scale (patch_mask , scale )
463465 transformation ["scale" ] = scale
464466
467+ # pad
468+ pad_h_before = int ((self .input_shape [self .i_h ] - patch .shape [self .i_h ]) / 2 )
469+ pad_h_after = int (self .input_shape [self .i_h ] - pad_h_before - patch .shape [self .i_h ])
470+
471+ pad_w_before = int ((self .input_shape [self .i_w ] - patch .shape [self .i_w ]) / 2 )
472+ pad_w_after = int (self .input_shape [self .i_w ] - pad_w_before - patch .shape [self .i_w ])
473+
474+ if self .estimator .channels_first :
475+ if self .nb_dims == 3 :
476+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
477+ elif self .nb_dims == 4 :
478+ pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
479+ else :
480+ if self .nb_dims == 3 :
481+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
482+ elif self .nb_dims == 4 :
483+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
484+
485+ transformation ["pad_h_before" ] = pad_h_before
486+ transformation ["pad_w_before" ] = pad_w_before
487+
488+ patch = np .pad (patch , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
489+ patch_mask = np .pad (patch_mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
490+
465491 # shift
466492 if mask_2d is None :
467- shift_max_h = (self .estimator . input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
468- shift_max_w = (self .estimator . input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
493+ shift_max_h = (self .input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
494+ shift_max_w = (self .input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
469495 if shift_max_h > 0 and shift_max_w > 0 :
470496 shift_h = random .uniform (- shift_max_h , shift_max_h )
471497 shift_w = random .uniform (- shift_max_w , shift_max_w )
@@ -488,8 +514,8 @@ def _random_transformation(self, patch, scale, mask_2d):
488514 num_pos = np .argwhere (mask_2d ).shape [0 ]
489515 pos_id = np .random .choice (num_pos , size = 1 )
490516 pos = np .argwhere (mask_2d )[pos_id [0 ]]
491- shift_h = pos [0 ] - (self .estimator . input_shape [self .i_h ]) / 2.0
492- shift_w = pos [1 ] - (self .estimator . input_shape [self .i_w ]) / 2.0
517+ shift_h = pos [0 ] - (self .input_shape [self .i_h ]) / 2.0
518+ shift_w = pos [1 ] - (self .input_shape [self .i_w ]) / 2.0
493519
494520 patch = self ._shift (patch , shift_h , shift_w )
495521 patch_mask = self ._shift (patch_mask , shift_h , shift_w )
@@ -507,6 +533,27 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
507533 shift_w = transformation ["shift_w" ]
508534 gradients = self ._shift (gradients , - shift_h , - shift_w )
509535
536+ # unpad
537+
538+ pad_h_before = transformation ["pad_h_before" ]
539+ pad_w_before = transformation ["pad_w_before" ]
540+
541+ if self .estimator .channels_first :
542+ height , width = self .patch_shape [1 ], self .patch_shape [2 ]
543+ else :
544+ height , width = self .patch_shape [0 ], self .patch_shape [1 ]
545+
546+ if self .estimator .channels_first :
547+ if self .nb_dims == 3 :
548+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
549+ elif self .nb_dims == 4 :
550+ gradients = gradients [:, :, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
551+ else :
552+ if self .nb_dims == 3 :
553+ gradients = gradients [pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
554+ elif self .nb_dims == 4 :
555+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
556+
510557 # scale
511558 scale = transformation ["scale" ]
512559 gradients = self ._scale (gradients , 1.0 / scale )
@@ -516,3 +563,18 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
516563 gradients = self ._rotate (gradients , - angle )
517564
518565 return gradients
566+
567+ def reset_patch (self , initial_patch_value : Optional [Union [float , np .ndarray ]]) -> None :
568+ """
569+ Reset the adversarial patch.
570+
571+ :param initial_patch_value: Patch value to use for resetting the patch.
572+ """
573+ if initial_patch_value is None :
574+ self .patch = np .ones (shape = self .patch_shape ).astype (np .float32 ) * self .mean_value
575+ elif isinstance (initial_patch_value , float ):
576+ self .patch = np .ones (shape = self .patch_shape ).astype (np .float32 ) * initial_patch_value
577+ elif self .patch .shape == initial_patch_value .shape :
578+ self .patch = initial_patch_value
579+ else :
580+ raise ValueError ("Unexpected value for initial_patch_value." )
0 commit comments