@@ -111,38 +111,36 @@ def __init__(
111111 "Unexpected input_shape in estimator detected. AdversarialPatch is expecting images or videos as input."
112112 )
113113
114- self .image_shape = self .estimator .input_shape
114+ self .input_shape = self .estimator .input_shape
115115
116- self .i_h_patch = 0
117- self .i_w_patch = 1
118-
119- self .nb_dims = len (self .image_shape )
116+ self .nb_dims = len (self .input_shape )
120117 if self .nb_dims == 3 :
121118 if self .estimator .channels_first :
119+ self .i_c = 0
122120 self .i_h = 1
123121 self .i_w = 2
124122 else :
125123 self .i_h = 0
126124 self .i_w = 1
125+ self .i_c = 2
127126 elif self .nb_dims == 4 :
128127 if self .estimator .channels_first :
128+ self .i_c = 1
129129 self .i_h = 2
130130 self .i_w = 3
131131 else :
132132 self .i_h = 1
133133 self .i_w = 2
134+ self .i_c = 3
135+
136+ smallest_image_edge = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
137+ nb_channels = self .input_shape [self .i_c ]
134138
135139 if self .estimator .channels_first :
136- smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
137- nb_channels = self .image_shape [0 ]
138140 self .patch_shape = (nb_channels , smallest_image_edge , smallest_image_edge )
139141 else :
140- smallest_image_edge = np .minimum (self .image_shape [0 ], self .image_shape [1 ])
141- nb_channels = self .image_shape [2 ]
142142 self .patch_shape = (smallest_image_edge , smallest_image_edge , nb_channels )
143143
144- self .patch_shape = self .image_shape
145-
146144 self .patch = None
147145 self .mean_value = (
148146 self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]
@@ -167,6 +165,14 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
167165 """
168166 logger .info ("Creating adversarial patch." )
169167
168+ test_input_shape = list (self .estimator .input_shape )
169+
170+ for i , size in enumerate (self .estimator .input_shape ):
171+ if size is None or size != x .shape [i + 1 ]:
172+ test_input_shape [i ] = x .shape [i + 1 ]
173+
174+ self .input_shape = tuple (test_input_shape )
175+
170176 mask = kwargs .get ("mask" )
171177 if mask is not None :
172178 mask = mask .copy ()
@@ -214,6 +220,8 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
214220 patch_gradients_i = self ._reverse_transformation (
215221 gradients [i_image , :, :, :], patch_mask_transformed [i_image , :, :, :], transforms [i_image ],
216222 )
223+ if self .nb_dims == 4 :
224+ patch_gradients_i = np .mean (patch_gradients_i , axis = 0 )
217225 patch_gradients += patch_gradients_i
218226
219227 # patch_gradients = patch_gradients / (num_batches * self.batch_size)
@@ -282,7 +290,7 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
282290 """
283291 Return a circular patch mask
284292 """
285- diameter = np .minimum (self .patch_shape [self .i_h ], self .patch_shape [self .i_w ])
293+ diameter = np .minimum (self .input_shape [self .i_h ], self .input_shape [self .i_w ])
286294
287295 x = np .linspace (- 1 , 1 , diameter )
288296 y = np .linspace (- 1 , 1 , diameter )
@@ -294,26 +302,12 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
294302 channel_index = 1 if self .estimator .channels_first else 3
295303 axis = channel_index - 1
296304 mask = np .expand_dims (mask , axis = axis )
297- mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
298305
299- pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
300- pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
301-
302- pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
303- pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
304-
305- if self .estimator .channels_first :
306- if self .nb_dims == 3 :
307- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
308- elif self .nb_dims == 4 :
309- pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
310- else :
311- if self .nb_dims == 3 :
312- pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
313- elif self .nb_dims == 4 :
314- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
306+ mask = np .broadcast_to (mask , self .patch_shape ).astype (np .float32 )
315307
316- mask = np .pad (mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
308+ if self .nb_dims == 4 :
309+ mask = np .expand_dims (mask , axis = 0 )
310+ mask = np .repeat (mask , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
317311
318312 return mask
319313
@@ -361,22 +355,18 @@ def _rotate(self, x, angle):
361355
362356 def _scale (self , x , scale ):
363357 zooms = None
364- height = None
365- width = None
358+ height , width = x . shape [ self . i_h ], x . shape [ self . i_w ]
359+
366360 if self .estimator .channels_first :
367361 if self .nb_dims == 3 :
368362 zooms = (1.0 , scale , scale )
369- height , width = self .patch_shape [1 :3 ]
370363 elif self .nb_dims == 4 :
371364 zooms = (1.0 , 1.0 , scale , scale )
372- height , width = self .patch_shape [2 :4 ]
373365 elif not self .estimator .channels_first :
374366 if self .nb_dims == 3 :
375367 zooms = (scale , scale , 1.0 )
376- height , width = self .patch_shape [0 :2 ]
377368 elif self .nb_dims == 4 :
378369 zooms = (1.0 , scale , scale , 1.0 )
379- height , width = self .patch_shape [1 :3 ]
380370
381371 if scale < 1.0 :
382372 scale_h = int (np .round (height * scale ))
@@ -457,6 +447,10 @@ def _random_transformation(self, patch, scale, mask_2d):
457447 patch_mask = self ._get_circular_patch_mask ()
458448 transformation = dict ()
459449
450+ if self .nb_dims == 4 :
451+ patch = np .expand_dims (patch , axis = 0 )
452+ patch = np .repeat (patch , axis = 0 , repeats = self .input_shape [0 ]).astype (np .float32 )
453+
460454 # rotate
461455 angle = random .uniform (- self .rotation_max , self .rotation_max )
462456 transformation ["rotate" ] = angle
@@ -470,10 +464,34 @@ def _random_transformation(self, patch, scale, mask_2d):
470464 patch_mask = self ._scale (patch_mask , scale )
471465 transformation ["scale" ] = scale
472466
467+ # pad
468+ pad_h_before = int ((self .input_shape [self .i_h ] - patch .shape [self .i_h ]) / 2 )
469+ pad_h_after = int (self .input_shape [self .i_h ] - pad_h_before - patch .shape [self .i_h ])
470+
471+ pad_w_before = int ((self .input_shape [self .i_w ] - patch .shape [self .i_w ]) / 2 )
472+ pad_w_after = int (self .input_shape [self .i_w ] - pad_w_before - patch .shape [self .i_w ])
473+
474+ if self .estimator .channels_first :
475+ if self .nb_dims == 3 :
476+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
477+ elif self .nb_dims == 4 :
478+ pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )) # type: ignore
479+ else :
480+ if self .nb_dims == 3 :
481+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
482+ elif self .nb_dims == 4 :
483+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 )) # type: ignore
484+
485+ transformation ["pad_h_before" ] = pad_h_before
486+ transformation ["pad_w_before" ] = pad_w_before
487+
488+ patch = np .pad (patch , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
489+ patch_mask = np .pad (patch_mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
490+
473491 # shift
474492 if mask_2d is None :
475- shift_max_h = (self .estimator . input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
476- shift_max_w = (self .estimator . input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
493+ shift_max_h = (self .input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
494+ shift_max_w = (self .input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
477495 if shift_max_h > 0 and shift_max_w > 0 :
478496 shift_h = random .uniform (- shift_max_h , shift_max_h )
479497 shift_w = random .uniform (- shift_max_w , shift_max_w )
@@ -496,8 +514,8 @@ def _random_transformation(self, patch, scale, mask_2d):
496514 num_pos = np .argwhere (mask_2d ).shape [0 ]
497515 pos_id = np .random .choice (num_pos , size = 1 )
498516 pos = np .argwhere (mask_2d )[pos_id [0 ]]
499- shift_h = pos [0 ] - (self .estimator . input_shape [self .i_h ]) / 2.0
500- shift_w = pos [1 ] - (self .estimator . input_shape [self .i_w ]) / 2.0
517+ shift_h = pos [0 ] - (self .input_shape [self .i_h ]) / 2.0
518+ shift_w = pos [1 ] - (self .input_shape [self .i_w ]) / 2.0
501519
502520 patch = self ._shift (patch , shift_h , shift_w )
503521 patch_mask = self ._shift (patch_mask , shift_h , shift_w )
@@ -515,6 +533,27 @@ def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed,
515533 shift_w = transformation ["shift_w" ]
516534 gradients = self ._shift (gradients , - shift_h , - shift_w )
517535
536+ # unpad
537+
538+ pad_h_before = transformation ["pad_h_before" ]
539+ pad_w_before = transformation ["pad_w_before" ]
540+
541+ if self .estimator .channels_first :
542+ height , width = self .patch_shape [1 ], self .patch_shape [2 ]
543+ else :
544+ height , width = self .patch_shape [0 ], self .patch_shape [1 ]
545+
546+ if self .estimator .channels_first :
547+ if self .nb_dims == 3 :
548+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
549+ elif self .nb_dims == 4 :
550+ gradients = gradients [:, :, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width ]
551+ else :
552+ if self .nb_dims == 3 :
553+ gradients = gradients [pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
554+ elif self .nb_dims == 4 :
555+ gradients = gradients [:, pad_h_before : pad_h_before + height , pad_w_before : pad_w_before + width , :]
556+
518557 # scale
519558 scale = transformation ["scale" ]
520559 gradients = self ._scale (gradients , 1.0 / scale )
0 commit comments