Commit 8171abe

Beat Buesser authored and committed:
Add support for rectangular input images
Signed-off-by: Beat Buesser <[email protected]>
1 parent 122ce3b commit 8171abe

1 file changed: +96 -58

art/attacks/evasion/adversarial_patch/adversarial_patch_numpy.py

Lines changed: 96 additions & 58 deletions
@@ -99,15 +99,33 @@ def __init__(
         self.learning_rate = learning_rate
         self.max_iter = max_iter
         self.batch_size = batch_size
-        self.patch_shape = self.estimator.input_shape
         self.clip_patch = clip_patch
         self._check_params()
 
+        self.image_shape = self.estimator.input_shape
+
+        if self.estimator.channels_first:
+            self.i_h = 1
+            self.i_w = 2
+        else:
+            self.i_h = 0
+            self.i_w = 1
+
+        if self.estimator.channels_first:
+            smallest_image_edge = np.minimum(self.image_shape[1], self.image_shape[2])
+            nb_channels = self.image_shape[0]
+            self.patch_shape = (nb_channels, smallest_image_edge, smallest_image_edge)
+        else:
+            smallest_image_edge = np.minimum(self.image_shape[0], self.image_shape[1])
+            nb_channels = self.image_shape[2]
+            self.patch_shape = (smallest_image_edge, smallest_image_edge, nb_channels)
+
+        self.patch_shape = self.image_shape
+
         mean_value = (self.estimator.clip_values[1] - self.estimator.clip_values[0]) / 2.0 + self.estimator.clip_values[
             0
         ]
         self.patch = np.ones(shape=self.patch_shape).astype(np.float32) * mean_value
-        self.patch[int(self.patch.shape[0] / 2), :, :] = 1
 
     def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
         """
@@ -203,17 +221,27 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
         """
         Return a circular patch mask
         """
-        diameter = self.patch_shape[1]
+        diameter = np.minimum(self.patch_shape[self.i_h], self.patch_shape[self.i_w])
+
         x = np.linspace(-1, 1, diameter)
         y = np.linspace(-1, 1, diameter)
         x_grid, y_grid = np.meshgrid(x, y, sparse=True)
         z_grid = (x_grid ** 2 + y_grid ** 2) ** sharpness
 
         mask = 1 - np.clip(z_grid, -1, 1)
 
-        pad_1 = int((self.patch_shape[1] - mask.shape[1]) / 2)
-        pad_2 = int(self.patch_shape[1] - pad_1 - mask.shape[1])
-        mask = np.pad(mask, pad_width=(pad_1, pad_2), mode="constant", constant_values=(0, 0))
+        pad_h_before = int((self.image_shape[self.i_h] - mask.shape[self.i_h]) / 2)
+        pad_h_after = int(self.image_shape[self.i_h] - pad_h_before - mask.shape[self.i_h])
+
+        pad_w_before = int((self.image_shape[self.i_w] - mask.shape[self.i_w]) / 2)
+        pad_w_after = int(self.image_shape[self.i_w] - pad_w_before - mask.shape[self.i_w])
+
+        mask = np.pad(
+            mask,
+            pad_width=((pad_h_before, pad_h_after), (pad_w_before, pad_w_after)),
+            mode="constant",
+            constant_values=(0, 0),
+        )
 
         channel_index = 1 if self.estimator.channels_first else 3
         axis = channel_index - 1
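
As a sanity check on the padding arithmetic above, a small self-contained sketch (channels-last assumption, hypothetical 480x640 input; not code from the commit) that centers a square mask inside a rectangular canvas with np.pad:

import numpy as np

image_shape = (480, 640, 3)  # hypothetical channels-last input (H, W, C)
i_h, i_w = 0, 1

diameter = int(np.minimum(image_shape[i_h], image_shape[i_w]))  # 480
mask = np.ones((diameter, diameter))  # stand-in for the circular mask

# Asymmetric padding so the mask stays centered even when H != W.
pad_h_before = int((image_shape[i_h] - mask.shape[i_h]) / 2)
pad_h_after = int(image_shape[i_h] - pad_h_before - mask.shape[i_h])
pad_w_before = int((image_shape[i_w] - mask.shape[i_w]) / 2)
pad_w_after = int(image_shape[i_w] - pad_w_before - mask.shape[i_w])

mask = np.pad(
    mask,
    pad_width=((pad_h_before, pad_h_after), (pad_w_before, pad_w_after)),
    mode="constant",
    constant_values=(0, 0),
)
print(mask.shape)  # (480, 640)
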
@@ -250,57 +278,67 @@ def _augment_images_with_random_patch(self, images, patch, scale=None):
         return patched_images, patch_mask_transformed_np, transformations
 
     def _rotate(self, x, angle):
-        axes = None
-        if not self.estimator.channels_first:
-            axes = (0, 1)
-        elif self.estimator.channels_first:
-            axes = (1, 2)
+        axes = (self.i_h, self.i_w)
         return rotate(x, angle=angle, reshape=False, axes=axes, order=1)
 
-    def _scale(self, x, scale, shape):
+    def _scale(self, x, scale):
         zooms = None
-        if not self.estimator.channels_first:
-            zooms = (scale, scale, 1.0)
-        elif self.estimator.channels_first:
+        height = None
+        width = None
+        if self.estimator.channels_first:
             zooms = (1.0, scale, scale)
-        x = zoom(x, zoom=zooms, order=1)
-
-        if x.shape[1] <= self.estimator.input_shape[1]:
-            pad_1 = int((shape - x.shape[1]) / 2)
-            pad_2 = int(shape - pad_1 - x.shape[1])
-            if not self.estimator.channels_first:
-                pad_width = ((pad_1, pad_2), (pad_1, pad_2), (0, 0))
-            elif self.estimator.channels_first:
-                pad_width = ((0, 0), (pad_1, pad_2), (pad_1, pad_2))
+            height, width = self.patch_shape[1:3]
+        elif not self.estimator.channels_first:
+            zooms = (scale, scale, 1.0)
+            height, width = self.patch_shape[0:2]
+
+        if scale < 1.0:
+            scale_h = int(np.round(height * scale))
+            scale_w = int(np.round(width * scale))
+            top = (height - scale_h) // 2
+            left = (width - scale_w) // 2
+
+            x_out = np.zeros_like(x)
+            x_out[top : top + scale_h, left : left + scale_w] = zoom(x, zoom=zooms, order=1)
+
+            if self.estimator.channels_first:
+                x_out[:, top : top + scale_h, left : left + scale_w] = zoom(x, zoom=zooms, order=1)
             else:
-                pad_width = None
-            x = np.pad(x, pad_width=pad_width, mode="constant", constant_values=(0, 0))
-        else:
-            center = int(x.shape[1] / 2)
-            patch_hw_1 = int(self.estimator.input_shape[1] / 2)
-            patch_hw_2 = self.estimator.input_shape[1] - patch_hw_1
-            if not self.estimator.channels_first:
-                x = x[center - patch_hw_1 : center + patch_hw_2, center - patch_hw_1 : center + patch_hw_2, :]
-            elif self.estimator.channels_first:
-                x = x[:, center - patch_hw_1 : center + patch_hw_2, center - patch_hw_1 : center + patch_hw_2]
+                x_out[top : top + scale_h, left : left + scale_w, :] = zoom(x, zoom=zooms, order=1)
+
+        elif scale > 1.0:
+            scale_h = int(np.round(height / scale))
+            scale_w = int(np.round(width / scale))
+            top = (height - scale_h) // 2
+            left = (width - scale_w) // 2
+
+            x_out = zoom(x[top : top + scale_h, left : left + scale_w], zoom=zooms, order=1)
+
+            cut_top = (x_out.shape[self.i_h] - height) // 2
+            cut_left = (x_out.shape[self.i_w] - width) // 2
+
+            if self.estimator.channels_first:
+                x_out = x_out[:, cut_top : cut_top + height, cut_left : cut_left + width]
             else:
-                x = None
+                x_out = x_out[cut_top : cut_top + height, cut_left : cut_left + width, :]
+
+        else:
+            x_out = x
 
-        return x
+        return x_out
+
+    def _shift(self, x, shift_h, shift_w):
+        if self.estimator.channels_first:
+            shift_hw = (0, shift_h, shift_w)
+        else:
+            shift_hw = (shift_h, shift_w, 0)
 
-    def _shift(self, x, shift_1, shift_2):
-        shift_xy = None
-        if not self.estimator.channels_first:
-            shift_xy = (shift_1, shift_2, 0)
-        elif self.estimator.channels_first:
-            shift_xy = (0, shift_1, shift_2)
-        x = shift(x, shift=shift_xy, order=1)
-        return x, shift_1, shift_2
+        x = shift(x, shift=shift_hw, order=1)
+        return x, shift_h, shift_w
 
     def _random_transformation(self, patch, scale):
         patch_mask = self._get_circular_patch_mask()
         transformation = dict()
-        shape = patch_mask.shape[1]
 
         # rotate
         angle = random.uniform(-self.rotation_max, self.rotation_max)
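
The two scaling branches introduced in _scale above can be read in isolation with a minimal standalone sketch (channels-last only, hypothetical shapes; the helper name scale_patch is ours, not the library's): for scale < 1.0 the zoomed patch is written into a zero-filled canvas of the original size, and for scale > 1.0 a center crop is zoomed up and then cut back to the original height and width.

import numpy as np
from scipy.ndimage import zoom

def scale_patch(x, scale):
    # Channels-last sketch of the new _scale logic; x has shape (H, W, C).
    height, width = x.shape[0:2]
    zooms = (scale, scale, 1.0)

    if scale < 1.0:
        # Shrink, then paste the result into the center of a zero canvas.
        scale_h = int(np.round(height * scale))
        scale_w = int(np.round(width * scale))
        top = (height - scale_h) // 2
        left = (width - scale_w) // 2
        x_out = np.zeros_like(x)
        x_out[top : top + scale_h, left : left + scale_w, :] = zoom(x, zoom=zooms, order=1)
    elif scale > 1.0:
        # Zoom a center crop up, then cut back to the original size.
        scale_h = int(np.round(height / scale))
        scale_w = int(np.round(width / scale))
        top = (height - scale_h) // 2
        left = (width - scale_w) // 2
        x_out = zoom(x[top : top + scale_h, left : left + scale_w], zoom=zooms, order=1)
        cut_top = (x_out.shape[0] - height) // 2
        cut_left = (x_out.shape[1] - width) // 2
        x_out = x_out[cut_top : cut_top + height, cut_left : cut_left + width, :]
    else:
        x_out = x
    return x_out

patch = np.random.rand(100, 150, 3).astype(np.float32)
print(scale_patch(patch, 0.5).shape, scale_patch(patch, 2.0).shape)  # (100, 150, 3) (100, 150, 3)

Compared with the old implementation, no fixed shape argument is needed because the padding and cropping are computed from the patch's own height and width.
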
@@ -311,17 +349,18 @@ def _random_transformation(self, patch, scale):
         # scale
         if scale is None:
             scale = random.uniform(self.scale_min, self.scale_max)
-        patch = self._scale(patch, scale, shape)
-        patch_mask = self._scale(patch_mask, scale, shape)
+        patch = self._scale(patch, scale)
+        patch_mask = self._scale(patch_mask, scale)
         transformation["scale"] = scale
 
         # shift
-        shift_max = (self.estimator.input_shape[1] - self.patch_shape[1] * scale) / 2.0
-        if shift_max > 0:
-            shift_1 = random.uniform(-shift_max, shift_max)
-            shift_2 = random.uniform(-shift_max, shift_max)
-            patch, _, _ = self._shift(patch, shift_1, shift_2)
-            patch_mask, shift_1, shift_2 = self._shift(patch_mask, shift_1, shift_2)
+        shift_max_h = (self.estimator.input_shape[self.i_h] - self.patch_shape[self.i_h] * scale) / 2.0
+        shift_max_w = (self.estimator.input_shape[self.i_w] - self.patch_shape[self.i_w] * scale) / 2.0
+        if shift_max_h > 0 and shift_max_w > 0:
+            shift_h = random.uniform(-shift_max_h, shift_max_h)
+            shift_w = random.uniform(-shift_max_w, shift_max_w)
+            patch, _, _ = self._shift(patch, shift_h, shift_w)
+            patch_mask, shift_1, shift_2 = self._shift(patch_mask, shift_h, shift_w)
             transformation["shift_1"] = shift_1
             transformation["shift_2"] = shift_2
         else:
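
The per-axis shift bounds introduced above can be checked with a couple of hypothetical numbers (not taken from the commit): with a 480x640 channels-last input and a 480x480 patch at scale 0.8, the patch can move further horizontally than vertically.

input_shape = (480, 640, 3)  # hypothetical channels-last input
patch_shape = (480, 480, 3)  # hypothetical square patch buffer
i_h, i_w = 0, 1
scale = 0.8

shift_max_h = (input_shape[i_h] - patch_shape[i_h] * scale) / 2.0  # 48.0
shift_max_w = (input_shape[i_w] - patch_shape[i_w] * scale) / 2.0  # 128.0
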
@@ -330,17 +369,16 @@ def _random_transformation(self, patch, scale):
         return patch, patch_mask, transformation
 
     def _reverse_transformation(self, gradients: np.ndarray, patch_mask_transformed, transformation) -> np.ndarray:
-        shape = gradients.shape[1]
         gradients = gradients * patch_mask_transformed
 
         # shift
-        shift_1 = transformation["shift_1"]
-        shift_2 = transformation["shift_2"]
-        gradients, _, _ = self._shift(gradients, -shift_1, -shift_2)
+        shift_h = transformation["shift_h"]
+        shift_w = transformation["shift_w"]
+        gradients, _, _ = self._shift(gradients, -shift_h, -shift_w)
 
         # scale
         scale = transformation["scale"]
-        gradients = self._scale(gradients, 1.0 / scale, shape)
+        gradients = self._scale(gradients, 1.0 / scale)
 
         # rotate
         angle = transformation["rotate"]