2424from __future__ import absolute_import , division , print_function , unicode_literals
2525
2626import logging
27+ import math
2728from typing import Optional , Union
2829
2930import random
@@ -100,11 +101,49 @@ def __init__(
100101 self .max_iter = max_iter
101102 self .batch_size = batch_size
102103 self .clip_patch = clip_patch
104+ self ._check_params ()
105+
106+ if len (self .estimator .input_shape ) not in [3 , 4 ]:
107+ raise ValueError (
108+ "Unexpected input_shape in estimator detected. AdversarialPatch is expecting images or videos as input."
109+ )
110+
111+ self .image_shape = self .estimator .input_shape
112+
113+ self .i_h_patch = 0
114+ self .i_w_patch = 1
115+
116+ self .nb_dims = len (self .image_shape )
117+ if self .nb_dims == 3 :
118+ if self .estimator .channels_first :
119+ self .i_h = 1
120+ self .i_w = 2
121+ else :
122+ self .i_h = 0
123+ self .i_w = 1
124+ elif self .nb_dims == 4 :
125+ if self .estimator .channels_first :
126+ self .i_h = 2
127+ self .i_w = 3
128+ else :
129+ self .i_h = 1
130+ self .i_w = 2
131+
132+ if self .estimator .channels_first :
133+ smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
134+ nb_channels = self .image_shape [0 ]
135+ self .patch_shape = (nb_channels , smallest_image_edge , smallest_image_edge )
136+ else :
137+ smallest_image_edge = np .minimum (self .image_shape [0 ], self .image_shape [1 ])
138+ nb_channels = self .image_shape [2 ]
139+ self .patch_shape = (smallest_image_edge , smallest_image_edge , nb_channels )
140+
141+ self .patch_shape = self .image_shape
142+
103143 mean_value = (self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]) / 2.0 + self .estimator .clip_values [
104144 0
105145 ]
106- self .patch = np .ones (shape = self .estimator .input_shape ).astype (np .float32 ) * mean_value
107- self ._check_params ()
146+ self .patch = np .ones (shape = self .patch_shape ).astype (np .float32 ) * mean_value
108147
109148 def generate (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
110149 """
@@ -124,10 +163,10 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
124163
125164 y_target = check_and_transform_label_format (labels = y , nb_classes = self .estimator .nb_classes )
126165
127- for _ in trange (self .max_iter , desc = "Adversarial patch " ):
166+ for _ in trange (self .max_iter , desc = "Adversarial Patch Numpy " ):
128167 patched_images , patch_mask_transformed , transforms = self ._augment_images_with_random_patch (x , self .patch )
129168
130- num_batches = int (x .shape [0 ] / self .batch_size )
169+ num_batches = int (math . ceil ( x .shape [0 ] / self .batch_size ) )
131170 patch_gradients = np .zeros_like (self .patch )
132171
133172 for i_batch in range (num_batches ):
@@ -138,7 +177,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
138177 patched_images [i_batch_start :i_batch_end ], y_target [i_batch_start :i_batch_end ],
139178 )
140179
141- for i_image in range (self . batch_size ):
180+ for i_image in range (gradients . shape [ 0 ] ):
142181 patch_gradients_i = self ._reverse_transformation (
143182 gradients [i_image , :, :, :], patch_mask_transformed [i_image , :, :, :], transforms [i_image ],
144183 )
@@ -200,22 +239,39 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
200239 """
201240 Return a circular patch mask
202241 """
203- diameter = self .estimator .input_shape [1 ]
242+ diameter = np .minimum (self .patch_shape [self .i_h ], self .patch_shape [self .i_w ])
243+
204244 x = np .linspace (- 1 , 1 , diameter )
205245 y = np .linspace (- 1 , 1 , diameter )
206246 x_grid , y_grid = np .meshgrid (x , y , sparse = True )
207247 z_grid = (x_grid ** 2 + y_grid ** 2 ) ** sharpness
208248
209249 mask = 1 - np .clip (z_grid , - 1 , 1 )
210250
211- pad_1 = int ((self .estimator .input_shape [1 ] - mask .shape [1 ]) / 2 )
212- pad_2 = int (self .estimator .input_shape [1 ] - pad_1 - mask .shape [1 ])
213- mask = np .pad (mask , pad_width = (pad_1 , pad_2 ), mode = "constant" , constant_values = (0 , 0 ))
214-
215251 channel_index = 1 if self .estimator .channels_first else 3
216252 axis = channel_index - 1
217253 mask = np .expand_dims (mask , axis = axis )
218- mask = np .broadcast_to (mask , self .estimator .input_shape ).astype (np .float32 )
254+ mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
255+
256+ pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
257+ pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
258+
259+ pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
260+ pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
261+
262+ if self .estimator .channels_first :
263+ if self .nb_dims == 3 :
264+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ))
265+ elif self .nb_dims == 4 :
266+ pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ))
267+ else :
268+ if self .nb_dims == 3 :
269+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 ))
270+ elif self .nb_dims == 4 :
271+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 ))
272+
273+ mask = np .pad (mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
274+
219275 return mask
220276
221277 def _augment_images_with_random_patch (self , images , patch , scale = None ):
@@ -247,57 +303,106 @@ def _augment_images_with_random_patch(self, images, patch, scale=None):
247303 return patched_images , patch_mask_transformed_np , transformations
248304
def _rotate(self, x, angle):
    """
    Rotate `x` in its height/width plane.

    :param x: Array to rotate; its height and width axes are `self.i_h` and `self.i_w`.
    :param angle: Rotation angle in degrees, passed straight to `scipy.ndimage.rotate`.
    :return: Rotated array with the same shape as `x`.
    """
    # The spatial axes are looked up from the instance instead of being hard-coded,
    # so the same code serves channels-first/last layouts of 3-D and 4-D inputs.
    spatial_axes = (self.i_h, self.i_w)
    # reshape=False keeps the output shape equal to the input shape; order=1 is linear interpolation.
    return rotate(x, angle=angle, reshape=False, axes=spatial_axes, order=1)
256308
def _scale(self, x, scale):
    """
    Scale `x` around its centre along the height and width axes while keeping its shape.

    For `scale < 1` the shrunken content is pasted into the centre of a zero array.
    For `scale > 1` a central window is cropped, zoomed, and centre-cropped back to the
    original height and width. For `scale == 1` the input object itself is returned.

    :param x: Array to scale; expected to have the shape `self.patch_shape`.
    :param scale: Zoom factor applied to the height and width axes only.
    :return: Scaled array with the same shape as `x`.
    """
    height = self.patch_shape[self.i_h]
    width = self.patch_shape[self.i_w]

    # Only the two spatial axes are zoomed; every other axis keeps a factor of 1.0.
    zooms = [1.0] * self.nb_dims
    zooms[self.i_h] = scale
    zooms[self.i_w] = scale
    zooms = tuple(zooms)

    def _window(top, left, size_h, size_w):
        # Index tuple selecting a (size_h, size_w) spatial window and everything on the other axes.
        index = [slice(None)] * self.nb_dims
        index[self.i_h] = slice(top, top + size_h)
        index[self.i_w] = slice(left, left + size_w)
        return tuple(index)

    if scale < 1.0:
        # Shrink: the zoomed content is smaller than the canvas, so centre it on zeros.
        scale_h = int(np.round(height * scale))
        scale_w = int(np.round(width * scale))
        top = (height - scale_h) // 2
        left = (width - scale_w) // 2

        x_out = np.zeros_like(x)
        x_out[_window(top, left, scale_h, scale_w)] = zoom(x, zoom=zooms, order=1)

    elif scale > 1.0:
        # Enlarge: zoom only the central window that will remain visible. The extra
        # pixel makes the zoomed window at least as large as the target size.
        scale_h = int(np.round(height / scale)) + 1
        scale_w = int(np.round(width / scale)) + 1
        top = (height - scale_h) // 2
        left = (width - scale_w) // 2

        if scale_h <= height and scale_w <= width and top >= 0 and left >= 0:
            x_out = zoom(x[_window(top, left, scale_h, scale_w)], zoom=zooms, order=1)
        else:
            # The window would not fit inside the array (scale barely above 1): leave x unscaled.
            x_out = x

        # Trim the surplus evenly so the result has exactly (height, width).
        cut_top = (x_out.shape[self.i_h] - height) // 2
        cut_left = (x_out.shape[self.i_w] - width) // 2
        x_out = x_out[_window(cut_top, cut_left, height, width)]

    else:
        x_out = x

    # Internal invariant: scaling must never change the array shape.
    assert x.shape == x_out.shape

    return x_out
389+
def _shift(self, x, shift_h, shift_w):
    """
    Translate `x` along its height and width axes.

    :param x: Array to shift; its height and width axes are `self.i_h` and `self.i_w`.
    :param shift_h: Offset along the height axis (may be fractional or negative).
    :param shift_w: Offset along the width axis (may be fractional or negative).
    :return: Shifted array with the same shape as `x`.
    """
    # Every axis except height and width stays in place. Building the vector from
    # nb_dims avoids an unbound offset when the layout is neither 3-D nor 4-D.
    shift_hw = [0] * self.nb_dims
    shift_hw[self.i_h] = shift_h
    shift_hw[self.i_w] = shift_w
    # order=1 is linear interpolation; boundary handling is scipy's default.
    return shift(x, shift=tuple(shift_hw), order=1)
296402
297403 def _random_transformation (self , patch , scale ):
298404 patch_mask = self ._get_circular_patch_mask ()
299405 transformation = dict ()
300- shape = patch_mask .shape [1 ]
301406
302407 # rotate
303408 angle = random .uniform (- self .rotation_max , self .rotation_max )
@@ -308,38 +413,40 @@ def _random_transformation(self, patch, scale):
308413 # scale
309414 if scale is None :
310415 scale = random .uniform (self .scale_min , self .scale_max )
311- patch = self ._scale (patch , scale , shape )
312- patch_mask = self ._scale (patch_mask , scale , shape )
416+ patch = self ._scale (patch , scale )
417+ patch_mask = self ._scale (patch_mask , scale )
313418 transformation ["scale" ] = scale
314419
315420 # shift
316- shift_max = (self .estimator .input_shape [1 ] * (1.0 - scale )) / 2.0
317- if shift_max > 0 :
318- shift_1 = random .uniform (- shift_max , shift_max )
319- shift_2 = random .uniform (- shift_max , shift_max )
320- patch , _ , _ = self ._shift (patch , shift_1 , shift_2 )
321- patch_mask , shift_1 , shift_2 = self ._shift (patch_mask , shift_1 , shift_2 )
322- transformation ["shift_1" ] = shift_1
323- transformation ["shift_2" ] = shift_2
421+ shift_max_h = (self .estimator .input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
422+ shift_max_w = (self .estimator .input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
423+ if shift_max_h > 0 and shift_max_w > 0 :
424+ shift_h = random .uniform (- shift_max_h , shift_max_h )
425+ shift_w = random .uniform (- shift_max_w , shift_max_w )
426+ patch = self ._shift (patch , shift_h , shift_w )
427+ patch_mask = self ._shift (patch_mask , shift_h , shift_w )
428+ transformation ["shift_h" ] = shift_h
429+ transformation ["shift_w" ] = shift_w
324430 else :
325- transformation ["shift" ] = (0 , 0 , 0 )
431+ transformation ["shift_h" ] = 0
432+ transformation ["shift_w" ] = 0
326433
327434 return patch , patch_mask , transformation
328435
def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed, transformation) -> np.ndarray:
    """
    Map gradients from the transformed patch back onto the untransformed patch.

    :param gradients: Gradients for a single patched input.
    :param patch_mask_transformed: Transformed patch mask for the same input, multiplied element-wise
                                   with `gradients`.
    :param transformation: Dictionary with the entries "shift_h", "shift_w", "scale" and "rotate"
                           describing the transformation that was applied to the patch.
    :return: Gradients after the inverse shift, inverse scale and inverse rotation.
    """
    # Keep only the gradient contributions selected by the mask.
    gradients = gradients * patch_mask_transformed

    # The keys must match the ones written when the transformation was sampled
    # ("shift_h"/"shift_w", no trailing whitespace).
    # shift
    shift_h = transformation["shift_h"]
    shift_w = transformation["shift_w"]
    gradients = self._shift(gradients, -shift_h, -shift_w)

    # scale
    scale = transformation["scale"]
    gradients = self._scale(gradients, 1.0 / scale)

    # rotate
    angle = transformation["rotate"]
    gradients = self._rotate(gradients, -angle)

    return gradients
0 commit comments