@@ -103,17 +103,31 @@ def __init__(
103103 self .clip_patch = clip_patch
104104 self ._check_params ()
105105
106- if len (self .estimator .input_shape ) not in [3 ]:
107- raise ValueError ("Wrong input_shape in estimator detected. AdversarialPatch is expecting images as input." )
106+ if len (self .estimator .input_shape ) not in [3 , 4 ]:
107+ raise ValueError (
108+ "Unexpected input_shape in estimator detected. AdversarialPatch is expecting images or videos as input."
109+ )
108110
109111 self .image_shape = self .estimator .input_shape
110112
111- if self .estimator .channels_first :
112- self .i_h = 1
113- self .i_w = 2
114- else :
115- self .i_h = 0
116- self .i_w = 1
113+ self .i_h_patch = 0
114+ self .i_w_patch = 1
115+
116+ self .nb_dims = len (self .image_shape )
117+ if self .nb_dims == 3 :
118+ if self .estimator .channels_first :
119+ self .i_h = 1
120+ self .i_w = 2
121+ else :
122+ self .i_h = 0
123+ self .i_w = 1
124+ elif self .nb_dims == 4 :
125+ if self .estimator .channels_first :
126+ self .i_h = 2
127+ self .i_w = 3
128+ else :
129+ self .i_h = 1
130+ self .i_w = 2
117131
118132 if self .estimator .channels_first :
119133 smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
@@ -246,9 +260,15 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
246260 pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
247261
248262 if self .estimator .channels_first :
249- pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ))
263+ if self .nb_dims == 3 :
264+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ))
265+ elif self .nb_dims == 4 :
266+ pad_width = ((0 , 0 ), (0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ))
250267 else :
251- pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 ))
268+ if self .nb_dims == 3 :
269+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 ))
270+ elif self .nb_dims == 4 :
271+ pad_width = ((0 , 0 ), (pad_h_before , pad_h_after ), (pad_w_before , pad_w_after ), (0 , 0 ))
252272
253273 mask = np .pad (mask , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ),)
254274
@@ -291,11 +311,19 @@ def _scale(self, x, scale):
291311 height = None
292312 width = None
293313 if self .estimator .channels_first :
294- zooms = (1.0 , scale , scale )
295- height , width = self .patch_shape [1 :3 ]
314+ if self .nb_dims == 3 :
315+ zooms = (1.0 , scale , scale )
316+ height , width = self .patch_shape [1 :3 ]
317+ elif self .nb_dims == 4 :
318+ zooms = (1.0 , 1.0 , scale , scale )
319+ height , width = self .patch_shape [2 :4 ]
296320 elif not self .estimator .channels_first :
297- zooms = (scale , scale , 1.0 )
298- height , width = self .patch_shape [0 :2 ]
321+ if self .nb_dims == 3 :
322+ zooms = (scale , scale , 1.0 )
323+ height , width = self .patch_shape [0 :2 ]
324+ elif self .nb_dims == 4 :
325+ zooms = (1.0 , scale , scale , 1.0 )
326+ height , width = self .patch_shape [1 :3 ]
299327
300328 if scale < 1.0 :
301329 scale_h = int (np .round (height * scale ))
@@ -306,9 +334,15 @@ def _scale(self, x, scale):
306334 x_out = np .zeros_like (x )
307335
308336 if self .estimator .channels_first :
309- x_out [:, top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
337+ if self .nb_dims == 3 :
338+ x_out [:, top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
339+ elif self .nb_dims == 4 :
340+ x_out [:, :, top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
310341 else :
311- x_out [top : top + scale_h , left : left + scale_w , :] = zoom (x , zoom = zooms , order = 1 )
342+ if self .nb_dims == 3 :
343+ x_out [top : top + scale_h , left : left + scale_w , :] = zoom (x , zoom = zooms , order = 1 )
344+ elif self .nb_dims == 4 :
345+ x_out [:, top : top + scale_h , left : left + scale_w , :] = zoom (x , zoom = zooms , order = 1 )
312346
313347 elif scale > 1.0 :
314348 scale_h = int (np .round (height / scale )) + 1
@@ -317,17 +351,29 @@ def _scale(self, x, scale):
317351 left = (width - scale_w ) // 2
318352
319353 if self .estimator .channels_first :
320- x_out = zoom (x [:, top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
354+ if self .nb_dims == 3 :
355+ x_out = zoom (x [:, top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
356+ elif self .nb_dims == 4 :
357+ x_out = zoom (x [:, :, top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
321358 else :
322- x_out = zoom (x [top : top + scale_h , left : left + scale_w , :], zoom = zooms , order = 1 )
359+ if self .nb_dims == 3 :
360+ x_out = zoom (x [top : top + scale_h , left : left + scale_w , :], zoom = zooms , order = 1 )
361+ elif self .nb_dims == 4 :
362+ x_out = zoom (x [:, top : top + scale_h , left : left + scale_w , :], zoom = zooms , order = 1 )
323363
324364 cut_top = (x_out .shape [self .i_h ] - height ) // 2
325365 cut_left = (x_out .shape [self .i_w ] - width ) // 2
326366
327367 if self .estimator .channels_first :
328- x_out = x_out [:, cut_top : cut_top + height , cut_left : cut_left + width ]
368+ if self .nb_dims == 3 :
369+ x_out = x_out [:, cut_top : cut_top + height , cut_left : cut_left + width ]
370+ elif self .nb_dims == 4 :
371+ x_out = x_out [:, :, cut_top : cut_top + height , cut_left : cut_left + width ]
329372 else :
330- x_out = x_out [cut_top : cut_top + height , cut_left : cut_left + width , :]
373+ if self .nb_dims == 3 :
374+ x_out = x_out [cut_top : cut_top + height , cut_left : cut_left + width , :]
375+ elif self .nb_dims == 4 :
376+ x_out = x_out [:, cut_top : cut_top + height , cut_left : cut_left + width , :]
331377
332378 else :
333379 x_out = x
@@ -338,9 +384,15 @@ def _scale(self, x, scale):
338384
339385 def _shift (self , x , shift_h , shift_w ):
340386 if self .estimator .channels_first :
341- shift_hw = (0 , shift_h , shift_w )
387+ if self .nb_dims == 3 :
388+ shift_hw = (0 , shift_h , shift_w )
389+ elif self .nb_dims == 4 :
390+ shift_hw = (0 , 0 , shift_h , shift_w )
342391 else :
343- shift_hw = (shift_h , shift_w , 0 )
392+ if self .nb_dims == 3 :
393+ shift_hw = (shift_h , shift_w , 0 )
394+ elif self .nb_dims == 4 :
395+ shift_hw = (0 , shift_h , shift_w , 0 )
344396 return shift (x , shift = shift_hw , order = 1 )
345397
346398 def _random_transformation (self , patch , scale ):
0 commit comments