Skip to content

Commit b19917e

Browse files
authored
Merge pull request #1049 from Trusted-AI/development_issue_1047
Updates for DPatch fixing patch updating
2 parents b917a1b + 002a935 commit b19917e

File tree

1 file changed

+62
-33
lines changed

1 file changed

+62
-33
lines changed

art/attacks/evasion/dpatch.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ def generate( # pylint: disable=W0221
117117
mask = kwargs.get("mask")
118118
if mask is not None:
119119
mask = mask.copy()
120-
if (
121-
mask is not None
122-
and (mask.dtype != np.bool)
120+
if mask is not None and (
121+
mask.dtype != np.bool
123122
or not (mask.shape[0] == 1 or mask.shape[0] == x.shape[0])
124123
or not (
125124
(mask.shape[1] == x.shape[1] and mask.shape[2] == x.shape[2])
@@ -151,7 +150,12 @@ def generate( # pylint: disable=W0221
151150
self.target_label = target_label
152151

153152
patched_images, transforms = self._augment_images_with_patch(
154-
x, self._patch, random_location=True, channels_first=self.estimator.channels_first, mask=mask
153+
x,
154+
self._patch,
155+
random_location=True,
156+
channels_first=self.estimator.channels_first,
157+
mask=mask,
158+
transforms=None,
155159
)
156160
patch_target: List[Dict[str, np.ndarray]] = list()
157161

@@ -237,6 +241,15 @@ def generate( # pylint: disable=W0221
237241
a_max=self.estimator.clip_values[1],
238242
)
239243

244+
patched_images, _ = self._augment_images_with_patch(
245+
x,
246+
self._patch,
247+
random_location=False,
248+
channels_first=self.estimator.channels_first,
249+
mask=None,
250+
transforms=transforms,
251+
)
252+
240253
return self._patch
241254

242255
@staticmethod
@@ -246,6 +259,7 @@ def _augment_images_with_patch(
246259
random_location: bool,
247260
channels_first: bool,
248261
mask: Optional[np.ndarray] = None,
262+
transforms: List[Dict[str, int]] = None,
249263
) -> Tuple[np.ndarray, List[Dict[str, int]]]:
250264
"""
251265
Augment images with patch.
@@ -258,9 +272,16 @@ def _augment_images_with_patch(
258272
:param mask: An boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
259273
(N, H, W) without their channel dimensions. Any features for which the mask is True can be the
260274
center location of the patch during sampling.
275+
:param transforms: Patch transforms, requires `random_location=False`, and `mask=None`.
261276
:type mask: `np.ndarray`
262277
"""
263-
transformations = list()
278+
if transforms is not None:
279+
if random_location or mask is not None:
280+
raise ValueError(
281+
"Definition of patch locations in `locations` requires `random_location=False`, and `mask=None`."
282+
)
283+
284+
random_transformations = list()
264285
x_copy = x.copy()
265286
patch_copy = patch.copy()
266287

@@ -270,48 +291,56 @@ def _augment_images_with_patch(
270291

271292
for i_image in range(x.shape[0]):
272293

273-
if random_location:
274-
if mask is None:
275-
i_x_1 = random.randint(0, x_copy.shape[1] - 1 - patch_copy.shape[0])
276-
i_y_1 = random.randint(0, x_copy.shape[2] - 1 - patch_copy.shape[1])
277-
else:
294+
if transforms is None:
278295

279-
if mask.shape[0] == 1:
280-
mask_2d = mask[0, :, :]
296+
if random_location:
297+
if mask is None:
298+
i_x_1 = random.randint(0, x_copy.shape[1] - 1 - patch_copy.shape[0])
299+
i_y_1 = random.randint(0, x_copy.shape[2] - 1 - patch_copy.shape[1])
281300
else:
282-
mask_2d = mask[i_image, :, :]
283301

284-
edge_x_0 = patch_copy.shape[0] // 2
285-
edge_x_1 = patch_copy.shape[0] - edge_x_0
286-
edge_y_0 = patch_copy.shape[1] // 2
287-
edge_y_1 = patch_copy.shape[1] - edge_y_0
302+
if mask.shape[0] == 1:
303+
mask_2d = mask[0, :, :]
304+
else:
305+
mask_2d = mask[i_image, :, :]
288306

289-
mask_2d[0:edge_x_0, :] = False
290-
mask_2d[-edge_x_1:, :] = False
291-
mask_2d[:, 0:edge_y_0] = False
292-
mask_2d[:, -edge_y_1:] = False
307+
edge_x_0 = patch_copy.shape[0] // 2
308+
edge_x_1 = patch_copy.shape[0] - edge_x_0
309+
edge_y_0 = patch_copy.shape[1] // 2
310+
edge_y_1 = patch_copy.shape[1] - edge_y_0
293311

294-
num_pos = np.argwhere(mask_2d).shape[0]
295-
pos_id = np.random.choice(num_pos, size=1)
296-
pos = np.argwhere(mask_2d > 0)[pos_id[0]]
297-
i_x_1 = pos[0] - edge_x_0
298-
i_y_1 = pos[1] - edge_y_0
312+
mask_2d[0:edge_x_0, :] = False
313+
mask_2d[-edge_x_1:, :] = False
314+
mask_2d[:, 0:edge_y_0] = False
315+
mask_2d[:, -edge_y_1:] = False
299316

300-
else:
301-
i_x_1 = 0
302-
i_y_1 = 0
317+
num_pos = np.argwhere(mask_2d).shape[0]
318+
pos_id = np.random.choice(num_pos, size=1)
319+
pos = np.argwhere(mask_2d > 0)[pos_id[0]]
320+
i_x_1 = pos[0] - edge_x_0
321+
i_y_1 = pos[1] - edge_y_0
303322

304-
i_x_2 = i_x_1 + patch_copy.shape[0]
305-
i_y_2 = i_y_1 + patch_copy.shape[1]
323+
else:
324+
i_x_1 = 0
325+
i_y_1 = 0
326+
327+
i_x_2 = i_x_1 + patch_copy.shape[0]
328+
i_y_2 = i_y_1 + patch_copy.shape[1]
306329

307-
transformations.append({"i_x_1": i_x_1, "i_y_1": i_y_1, "i_x_2": i_x_2, "i_y_2": i_y_2})
330+
random_transformations.append({"i_x_1": i_x_1, "i_y_1": i_y_1, "i_x_2": i_x_2, "i_y_2": i_y_2})
331+
332+
else:
333+
i_x_1 = transforms[i_image]["i_x_1"]
334+
i_x_2 = transforms[i_image]["i_x_2"]
335+
i_y_1 = transforms[i_image]["i_y_1"]
336+
i_y_2 = transforms[i_image]["i_y_2"]
308337

309338
x_copy[i_image, i_x_1:i_x_2, i_y_1:i_y_2, :] = patch_copy
310339

311340
if channels_first:
312341
x_copy = np.transpose(x_copy, (0, 3, 1, 2))
313342

314-
return x_copy, transformations
343+
return x_copy, random_transformations
315344

316345
def apply_patch(
317346
self,

0 commit comments

Comments
 (0)