@@ -111,38 +111,36 @@ def __init__(
111111 "Unexpected input_shape in estimator detected. AdversarialPatch is expecting images or videos as input."
112112 )
113113
114- self .image_shape = self .estimator .input_shape
114+ self .input_shape = self .estimator .input_shape
115115
116- self .i_h_patch = 0
117- self .i_w_patch = 1
118-
119- self .nb_dims = len (self .image_shape )
116+ self .nb_dims = len (self .input_shape )
120117 if self .nb_dims == 3 :
121118 if self .estimator .channels_first :
119+ self .i_c = 0
122120 self .i_h = 1
123121 self .i_w = 2
124122 else :
125123 self .i_h = 0
126124 self .i_w = 1
125+ self .i_c = 2
127126 elif self .nb_dims == 4 :
128127 if self .estimator .channels_first :
128+ self .i_c = 1
129129 self .i_h = 2
130130 self .i_w = 3
131131 else :
132132 self .i_h = 1
133133 self .i_w = 2
134+ self .i_c = 3
135+
136+ smallest_image_edge = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
137+ nb_channels = self .input_shape [self .i_c ]
134138
135139 if self .estimator .channels_first :
136- smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
137- nb_channels = self .image_shape [0 ]
138140 self .patch_shape = (nb_channels , smallest_image_edge , smallest_image_edge )
139141 else :
140- smallest_image_edge = np .minimum (self .image_shape [0 ], self .image_shape [1 ])
141- nb_channels = self .image_shape [2 ]
142142 self .patch_shape = (smallest_image_edge , smallest_image_edge , nb_channels )
143143
144- self .patch_shape = self .image_shape
145-
146144 mean_value = (self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]) / 2.0 + self .estimator .clip_values [
147145 0
148146 ]
@@ -162,6 +160,14 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
162160 """
163161 logger .info ("Creating adversarial patch." )
164162
163+ test_input_shape = list (self .estimator .input_shape )
164+
165+ for i , size in enumerate (self .estimator .input_shape ):
166+ if size is None or size != x .shape [i + 1 ]:
167+ test_input_shape [i ] = x .shape [i + 1 ]
168+
169+ self .input_shape = tuple (test_input_shape )
170+
165171 mask = kwargs .get ("mask" )
166172 if mask is not None :
167173 mask = mask .copy ()
@@ -206,6 +212,8 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
206212 patch_gradients_i = self ._reverse_transformation (
207213 gradients [i_image , :, :, :], patch_mask_transformed [i_image , :, :, :], transforms [i_image ],
208214 )
215+ if self .nb_dims == 4 :
216+ patch_gradients_i = np .mean (patch_gradients_i , axis = 0 )
209217 patch_gradients += patch_gradients_i
210218
211219 # patch_gradients = patch_gradients / (num_batches * self.batch_size)
@@ -274,7 +282,7 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
274282 """
275283 Return a circular patch mask
276284 """
277- diameter = np .minimum (self .patch_shape [self .i_h ], self .patch_shape [self .i_w ])
285+ diameter = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
278286
279287 x = np .linspace (- 1 , 1 , diameter )
280288 y = np .linspace (- 1 , 1 , diameter )
@@ -286,26 +294,12 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
286294 channel_index = 1 if self .estimator .channels_first else 3
287295 axis = channel_index - 1
288296 mask = np .expand_dims (mask , axis = axis )
289- mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
290297
291- pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
292- pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
293-
294- pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
295- pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
296-
297- if self .estimator .channels_first :
298- if self .nb_dims == 3 :
299- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
300- elif self .nb_dims == 4 :
301- pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
302- else :
303- if self .nb_dims == 3 :
304- pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
305- elif self .nb_dims == 4 :
306- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
298+ mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
307299
308- mask = np .pad (mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
300+ if self .nb_dims == 4 :
301+ mask = np .expand_dims (mask , axis = 0 )
302+ mask = np .repeat (mask , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
309303
310304 return mask
311305
@@ -353,22 +347,18 @@ def _rotate(self, x, angle):
353347
354348 def _scale (self , x , scale ):
355349 zooms = None
356- height = None
357- width = None
350+ height , width = x . shape [ self . i_h ], x . shape [ self . i_w ]
351+
358352 if self .estimator .channels_first :
359353 if self .nb_dims == 3 :
360354 zooms = (1.0 , scale , scale )
361- height , width = self .patch_shape [1 :3 ]
362355 elif self .nb_dims == 4 :
363356 zooms = (1.0 , 1.0 , scale , scale )
364- height , width = self .patch_shape [2 :4 ]
365357 elif not self .estimator .channels_first :
366358 if self .nb_dims == 3 :
367359 zooms = (scale , scale , 1.0 )
368- height , width = self .patch_shape [0 :2 ]
369360 elif self .nb_dims == 4 :
370361 zooms = (1.0 , scale , scale , 1.0 )
371- height , width = self .patch_shape [1 :3 ]
372362
373363 if scale < 1.0 :
374364 scale_h = int (np .round (height * scale ))
@@ -449,6 +439,10 @@ def _random_transformation(self, patch, scale, mask_2d):
449439 patch_mask = self ._get_circular_patch_mask ()
450440 transformation = dict ()
451441
442+ if self .nb_dims == 4 :
443+ patch = np .expand_dims (patch , axis = 0 )
444+ patch = np .repeat (patch , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
445+
452446 # rotate
453447 angle = random .uniform (- self .rotation_max , self .rotation_max )
454448 transformation ["rotate" ] = angle
@@ -462,10 +456,34 @@ def _random_transformation(self, patch, scale, mask_2d):
462456 patch_mask = self ._scale (patch_mask , scale )
463457 transformation ["scale" ] = scale
464458
459+ # pad
460+ pad_h_before = int ((self .input_shape [self .i_h ] - patch .shape [self .i_h ]) / 2 )
461+ pad_h_after = int (self .input_shape [self .i_h ] - pad_h_before - patch .shape [self .i_h ])
462+
463+ pad_w_before = int ((self .input_shape [self .i_w ] - patch .shape [self .i_w ]) / 2 )
464+ pad_w_after = int (self .input_shape [self .i_w ] - pad_w_before - patch .shape [self .i_w ])
465+
466+ if self .estimator .channels_first :
467+ if self .nb_dims == 3 :
468+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
469+ elif self .nb_dims == 4 :
470+ pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
471+ else :
472+ if self .nb_dims == 3 :
473+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
474+ elif self .nb_dims == 4 :
475+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
476+
477+ transformation ["pad_h_before" ] = pad_h_before
478+ transformation ["pad_w_before" ] = pad_w_before
479+
480+ patch = np .pad (patch , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
481+ patch_mask = np .pad (patch_mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
482+
465483 # shift
466484 if mask_2d is None :
467- shift_max_h = (self .estimator . input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
468- shift_max_w = (self .estimator . input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
485+ shift_max_h = (self .input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
486+ shift_max_w = (self .input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
469487 if shift_max_h > 0 and shift_max_w > 0 :
470488 shift_h = random .uniform (- shift_max_h , shift_max_h )
471489 shift_w = random .uniform (- shift_max_w , shift_max_w )
@@ -488,8 +506,8 @@ def _random_transformation(self, patch, scale, mask_2d):
488506 num_pos = np .argwhere (mask_2d ).shape [0 ]
489507 pos_id = np .random .choice (num_pos , size = 1 )
490508 pos = np .argwhere (mask_2d )[pos_id [0 ]]
491- shift_h = pos [0 ] - (self .estimator . input_shape [self .i_h ]) / 2.0
492- shift_w = pos [1 ] - (self .estimator . input_shape [self .i_w ]) / 2.0
509+ shift_h = pos [0 ] - (self .input_shape [self .i_h ]) / 2.0
510+ shift_w = pos [1 ] - (self .input_shape [self .i_w ]) / 2.0
493511
494512 patch = self ._shift (patch , shift_h , shift_w )
495513 patch_mask = self ._shift (patch_mask , shift_h , shift_w )
@@ -507,6 +525,27 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
507525 shift_w = transformation ["shift_w" ]
508526 gradients = self ._shift (gradients , - shift_h , - shift_w )
509527
528+ # unpad
529+
530+ pad_h_before = transformation ["pad_h_before" ]
531+ pad_w_before = transformation ["pad_w_before" ]
532+
533+ if self .estimator .channels_first :
534+ height , width = self .patch_shape [1 ], self .patch_shape [2 ]
535+ else :
536+ height , width = self .patch_shape [0 ], self .patch_shape [1 ]
537+
538+ if self .estimator .channels_first :
539+ if self .nb_dims == 3 :
540+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
541+ elif self .nb_dims == 4 :
542+ gradients = gradients [:, :, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
543+ else :
544+ if self .nb_dims == 3 :
545+ gradients = gradients [pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
546+ elif self .nb_dims == 4 :
547+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
548+
510549 # scale
511550 scale = transformation ["scale" ]
512551 gradients = self ._scale (gradients , 1.0 / scale )
0 commit comments