2424from __future__ import absolute_import , division , print_function , unicode_literals
2525
2626import logging
27+ import math
2728from typing import Optional , Union
2829
2930import random
@@ -102,6 +103,9 @@ def __init__(
102103 self .clip_patch = clip_patch
103104 self ._check_params ()
104105
106+ if len (self .estimator .input_shape ) not in [3 ]:
107+ raise ValueError ("Wrong input_shape in estimator detected. AdversarialPatch is expecting images as input." )
108+
105109 self .image_shape = self .estimator .input_shape
106110
107111 if self .estimator .channels_first :
@@ -148,7 +152,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
148152 for _ in trange (self .max_iter , desc = "Adversarial Patch Numpy" ):
149153 patched_images , patch_mask_transformed , transforms = self ._augment_images_with_random_patch (x , self .patch )
150154
151- num_batches = int (x .shape [0 ] / self .batch_size )
155+ num_batches = int (math . ceil ( x .shape [0 ] / self .batch_size ) )
152156 patch_gradients = np .zeros_like (self .patch )
153157
154158 for i_batch in range (num_batches ):
@@ -159,7 +163,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
159163 patched_images [i_batch_start :i_batch_end ], y_target [i_batch_start :i_batch_end ],
160164 )
161165
162- for i_image in range (self . batch_size ):
166+ for i_image in range (gradients . shape [ 0 ] ):
163167 patch_gradients_i = self ._reverse_transformation (
164168 gradients [i_image , :, :, :], patch_mask_transformed [i_image , :, :, :], transforms [i_image ],
165169 )
@@ -230,23 +234,24 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
230234
231235 mask = 1 - np .clip (z_grid , - 1 , 1 )
232236
237+ channel_index = 1 if self .estimator .channels_first else 3
238+ axis = channel_index - 1
239+ mask = np .expand_dims (mask , axis = axis )
240+ mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
241+
233242 pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
234243 pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
235244
236245 pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
237246 pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
238247
239- mask = np . pad (
240- mask ,
241- pad_width = (( pad_h_before , pad_h_after ), ( pad_w_before , pad_w_after )),
242- mode = "constant" ,
243- constant_values = ( 0 , 0 ),
244- )
248+ if self . estimator . channels_first :
249+ pad_width = (( 0 , 0 ), ( pad_h_before , pad_h_after ), ( pad_w_before , pad_w_after ))
250+ else :
251+ pad_width = (( pad_h_before , pad_h_after ), ( pad_w_before , pad_w_after ), ( 0 , 0 ))
252+
253+ mask = np . pad ( mask , pad_width = pad_width , mode = "constant" , constant_values = ( 0 , 0 ), )
245254
246- channel_index = 1 if self .estimator .channels_first else 3
247- axis = channel_index - 1
248- mask = np .expand_dims (mask , axis = axis )
249- mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
250255 return mask
251256
252257 def _augment_images_with_random_patch (self , images , patch , scale = None ):
@@ -299,20 +304,22 @@ def _scale(self, x, scale):
299304 left = (width - scale_w ) // 2
300305
301306 x_out = np .zeros_like (x )
302- x_out [top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
303307
304308 if self .estimator .channels_first :
305309 x_out [:, top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
306310 else :
307311 x_out [top : top + scale_h , left : left + scale_w , :] = zoom (x , zoom = zooms , order = 1 )
308312
309313 elif scale > 1.0 :
310- scale_h = int (np .round (height / scale ))
311- scale_w = int (np .round (width / scale ))
314+ scale_h = int (np .round (height / scale )) + 1
315+ scale_w = int (np .round (width / scale )) + 1
312316 top = (height - scale_h ) // 2
313317 left = (width - scale_w ) // 2
314318
315- x_out = zoom (x [top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
319+ if self .estimator .channels_first :
320+ x_out = zoom (x [:, top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
321+ else :
322+ x_out = zoom (x [top : top + scale_h , left : left + scale_w , :], zoom = zooms , order = 1 )
316323
317324 cut_top = (x_out .shape [self .i_h ] - height ) // 2
318325 cut_left = (x_out .shape [self .i_w ] - width ) // 2
@@ -325,16 +332,16 @@ def _scale(self, x, scale):
325332 else :
326333 x_out = x
327334
335+ assert x .shape == x_out .shape
336+
328337 return x_out
329338
330339 def _shift (self , x , shift_h , shift_w ):
331340 if self .estimator .channels_first :
332341 shift_hw = (0 , shift_h , shift_w )
333342 else :
334343 shift_hw = (shift_h , shift_w , 0 )
335-
336- x = shift (x , shift = shift_hw , order = 1 )
337- return x , shift_h , shift_w
344+ return shift (x , shift = shift_hw , order = 1 )
338345
339346 def _random_transformation (self , patch , scale ):
340347 patch_mask = self ._get_circular_patch_mask ()
@@ -359,12 +366,13 @@ def _random_transformation(self, patch, scale):
359366 if shift_max_h > 0 and shift_max_w > 0 :
360367 shift_h = random .uniform (- shift_max_h , shift_max_h )
361368 shift_w = random .uniform (- shift_max_w , shift_max_w )
362- patch , _ , _ = self ._shift (patch , shift_h , shift_w )
363- patch_mask , shift_1 , shift_2 = self ._shift (patch_mask , shift_h , shift_w )
364- transformation ["shift_1" ] = shift_1
365- transformation ["shift_2" ] = shift_2
369+ patch = self ._shift (patch , shift_h , shift_w )
370+ patch_mask = self ._shift (patch_mask , shift_h , shift_w )
371+ transformation ["shift_h" ] = shift_h
372+ transformation ["shift_w" ] = shift_w
366373 else :
367- transformation ["shift" ] = (0 , 0 , 0 )
374+ transformation ["shift_h" ] = 0
375+ transformation ["shift_w" ] = 0
368376
369377 return patch , patch_mask , transformation
370378
@@ -374,7 +382,7 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
374382 # shift
375383 shift_h = transformation ["shift_h" ]
376384 shift_w = transformation ["shift_w" ]
377- gradients , _ , _ = self ._shift (gradients , - shift_h , - shift_w )
385+ gradients = self ._shift (gradients , - shift_h , - shift_w )
378386
379387 # scale
380388 scale = transformation ["scale" ]
@@ -383,4 +391,5 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
383391 # rotate
384392 angle = transformation ["rotate" ]
385393 gradients = self ._rotate (gradients , - angle )
394+
386395 return gradients
0 commit comments