
Commit 089b8d0

Merge pull request #599 from Trusted-AI/development_dpatch

Add targeted version of DPatch

2 parents 79cce25 + f399663

3 files changed: 96 additions & 23 deletions


art/attacks/evasion/dpatch.py

Lines changed: 55 additions & 16 deletions
@@ -23,7 +23,7 @@
 import logging
 import math
 import random
-from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 from tqdm import trange
@@ -78,15 +78,26 @@ def __init__(
         self.learning_rate = learning_rate
         self.max_iter = max_iter
         self.batch_size = batch_size
-        self._patch = np.ones(shape=patch_shape) * (self.estimator.clip_values[1] + self.estimator.clip_values[0]) / 2.0
+        self._patch = np.random.randint(
+            self.estimator.clip_values[0], self.estimator.clip_values[1], size=patch_shape
+        ).astype(np.float32)
         self._check_params()
 
-    def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
+        self.target_label = []
+
+    def generate(
+        self,
+        x: np.ndarray,
+        y: Optional[np.ndarray] = None,
+        target_label: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs
+    ) -> np.ndarray:
         """
         Generate DPatch.
 
         :param x: Sample images.
         :param y: Target labels for object detector.
+        :param target_label: The target label of the DPatch attack.
         :return: Adversarial patch.
         """
         channel_index = 1 if self.estimator.channels_first else x.ndim - 1
@@ -96,6 +107,17 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
             raise ValueError("The DPatch attack does not use target labels.")
         if x.ndim != 4:
             raise ValueError("The adversarial patch can only be applied to images.")
+        if target_label is not None:
+            if isinstance(target_label, int):
+                self.target_label = [target_label] * x.shape[0]
+            elif isinstance(target_label, np.ndarray):
+                if not (target_label.shape == (x.shape[0], 1) or target_label.shape == (x.shape[0],)):
+                    raise ValueError("The target_label has to be a 1-dimensional array.")
+                self.target_label = target_label.tolist()
+            else:
+                if not len(target_label) == x.shape[0] or not isinstance(target_label, list):
+                    raise ValueError("The target_label as list of integers needs to be of length equal to the number of images in `x`.")
+                self.target_label = target_label
 
         for i_step in trange(self.max_iter, desc="DPatch iteration"):
             if i_step == 0 or (i_step + 1) % 100 == 0:
@@ -106,19 +128,32 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
             )
             patch_target: List[Dict[str, np.ndarray]] = list()
 
-            for i_image in range(patched_images.shape[0]):
+            if self.target_label:
+
+                for i_image in range(patched_images.shape[0]):
+                    i_x_1 = transforms[i_image]["i_x_1"]
+                    i_x_2 = transforms[i_image]["i_x_2"]
+                    i_y_1 = transforms[i_image]["i_y_1"]
+                    i_y_2 = transforms[i_image]["i_y_2"]
 
-                i_x_1 = transforms[i_image]["i_x_1"]
-                i_x_2 = transforms[i_image]["i_x_2"]
-                i_y_1 = transforms[i_image]["i_y_1"]
-                i_y_2 = transforms[i_image]["i_y_2"]
+                    target_dict = dict()
+                    target_dict["boxes"] = np.asarray([[i_x_1, i_y_1, i_x_2, i_y_2]])
+                    target_dict["labels"] = np.asarray([self.target_label[i_image],])
+                    target_dict["scores"] = np.asarray([1.0,])
 
-                target_dict = dict()
-                target_dict["boxes"] = np.asarray([[i_x_1, i_y_1, i_x_2, i_y_2]])
-                target_dict["labels"] = np.asarray([1,])
-                target_dict["scores"] = np.asarray([1.0,])
+                    patch_target.append(target_dict)
 
-                patch_target.append(target_dict)
+            else:
+
+                predictions = self.estimator.predict(x=patched_images)
+
+                for i_image in range(patched_images.shape[0]):
+                    target_dict = dict()
+                    target_dict["boxes"] = predictions[i_image]["boxes"].detach().cpu().numpy()
+                    target_dict["labels"] = predictions[i_image]["labels"].detach().cpu().numpy()
+                    target_dict["scores"] = predictions[i_image]["scores"].detach().cpu().numpy()
+
+                    patch_target.append(target_dict)
 
             num_batches = math.ceil(x.shape[0] / self.batch_size)
             patch_gradients = np.zeros_like(self._patch)
@@ -131,7 +166,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
                     x=patched_images[i_batch_start:i_batch_end], y=patch_target[i_batch_start:i_batch_end],
                 )
 
-                for i_image in range(self.batch_size):
+                for i_image in range(patched_images.shape[0]):
 
                     i_x_1 = transforms[i_batch_start + i_image]["i_x_1"]
                     i_x_2 = transforms[i_batch_start + i_image]["i_x_2"]
@@ -143,9 +178,13 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
                     else:
                         patch_gradients_i = gradients[i_image, i_x_1:i_x_2, i_y_1:i_y_2, :]
 
-                    patch_gradients += patch_gradients_i
+                    patch_gradients = patch_gradients + patch_gradients_i
+
+            if self.target_label:
+                self._patch = self._patch - np.sign(patch_gradients) * self.learning_rate
+            else:
+                self._patch = self._patch + np.sign(patch_gradients) * self.learning_rate
 
-            self._patch -= patch_gradients * self.learning_rate
             self._patch = np.clip(
                 self._patch, a_min=self.estimator.clip_values[0], a_max=self.estimator.clip_values[1],
             )
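
Taken together, the dpatch.py changes give `generate` two modes: with `target_label` set, the patch target is that label inside the patch box and the patch follows sign-gradient descent; without it, the targets are the detector's own predictions and the patch ascends the gradient instead. A minimal usage sketch, assuming a COCO-pretrained `PyTorchFasterRCNN` wrapper and illustrative image sizes, labels, and hyperparameters (none of this setup appears in the commit):

import numpy as np

from art.attacks.evasion import DPatch
from art.estimators.object_detection import PyTorchFasterRCNN

# Wrap a pretrained torchvision Faster R-CNN as an ART object detector.
detector = PyTorchFasterRCNN(clip_values=(0, 255))

# Two sample images in NHWC format; sizes and pixel values are illustrative.
x = np.random.uniform(low=0, high=255, size=(2, 640, 640, 3)).astype(np.float32)

attack = DPatch(detector, patch_shape=(40, 40, 3), learning_rate=1.0, max_iter=100, batch_size=1)

# Targeted mode: push detections inside the patch region towards label 1,
# using the new sign-gradient descent update.
patch_targeted = attack.generate(x=x, target_label=1)

# Untargeted mode: patch targets come from the detector's own predictions
# and the update ascends the loss gradient.
patch_untargeted = attack.generate(x=x)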

art/estimators/object_detection/pytorch_faster_rcnn.py

Lines changed: 8 additions & 7 deletions
@@ -19,7 +19,7 @@
 This module implements the task specific estimator for Faster R-CNN v3 in PyTorch.
 """
 import logging
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import List, Dict, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 
@@ -29,6 +29,7 @@
 
 if TYPE_CHECKING:
     # pylint: disable=C0412
+    import torch
     import torchvision
 
     from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
@@ -134,7 +135,7 @@ def __init__(
         self._model.eval()
         self.attack_losses: Tuple[str, ...] = attack_losses
 
-    def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
+    def loss_gradient(self, x: np.ndarray, y: List[Dict[str, np.ndarray]], **kwargs) -> np.ndarray:
         """
         Compute the gradient of the loss function w.r.t. `x`.
 
@@ -158,9 +159,9 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
 
         if y is not None:
             for i, y_i in enumerate(y):
-                y[i]["boxes"] = torch.tensor(y_i["boxes"], dtype=torch.float).to(self._device)
-                y[i]["labels"] = torch.tensor(y_i["labels"], dtype=torch.int64).to(self._device)
-                y[i]["scores"] = torch.tensor(y_i["scores"]).to(self._device)
+                y[i]["boxes"] = torch.from_numpy(y_i["boxes"]).type(torch.float).to(self._device)
+                y[i]["labels"] = torch.from_numpy(y_i["labels"]).type(torch.int64).to(self._device)
+                y[i]["scores"] = torch.from_numpy(y_i["scores"]).to(self._device)
 
         transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
         image_tensor_list = list()
@@ -207,13 +208,13 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
 
         return grads
 
-    def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
+    def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[str, "torch.Tensor"]]:
         """
         Perform prediction for a batch of inputs.
 
         :param x: Samples of shape (nb_samples, height, width, nb_channels).
         :param batch_size: Batch size.
-        :return: Predictions of format `List[Dict[Tensor]]`, one for each input image. The
+        :return: Predictions of format `List[Dict[str, Tensor]]`, one for each input image. The
                  fields of the Dict are as follows:
 
                  - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values \
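
The tightened `loss_gradient` annotation documents the target format this estimator now expects: one dictionary of NumPy arrays per image, which the method converts to device tensors via `torch.from_numpy`. A sketch of a well-formed `y`, with box coordinates and labels chosen only for illustration:

import numpy as np

# One entry per image, matching the new List[Dict[str, np.ndarray]] annotation.
y = [
    {
        "boxes": np.asarray([[10.0, 10.0, 50.0, 50.0]], dtype=np.float32),  # [x1, y1, x2, y2]
        "labels": np.asarray([1], dtype=np.int64),
        "scores": np.asarray([1.0], dtype=np.float32),
    }
]

# Assuming the detector and x from the earlier sketch:
# grads = detector.loss_gradient(x=x[:1], y=y)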

tests/attacks/evasion/test_dpatch.py

Lines changed: 33 additions & 0 deletions
@@ -130,6 +130,39 @@ def test_augment_images_with_patch(random_location, image_format, fix_get_mnist_
     np.testing.assert_array_equal(patched_images[1, 2, :, 0], patched_images_column)
 
 
+def test_exceptions(get_default_mnist_subset, image_dl_estimator):
+    class ObjectDetector(BaseEstimator, LossGradientsMixin, ObjectDetectorMixin):
+
+        clip_values = (0, 1)
+        channels_first = False
+
+        def fit(self):
+            pass
+
+        def loss_gradient(self, x, y, **kwargs):
+            pass
+
+        def predict(self, x, **kwargs):
+            pass
+
+    estimator = ObjectDetector()
+
+    (x_train_mnist, y_train_mnist), (_, _) = get_default_mnist_subset
+
+    attack = DPatch(estimator=estimator, patch_shape=(4, 4, 1), learning_rate=5.0, max_iter=5, batch_size=16,)
+
+    with pytest.raises(ValueError, match="The DPatch attack does not use target labels."):
+        attack.generate(x=x_train_mnist, y=y_train_mnist)
+
+    with pytest.raises(
+        ValueError, match="The target_label as list of integers needs to be of length equal to the number of images in `x`."
+    ):
+        attack.generate(x=x_train_mnist, y=None, target_label=[1, 2, 3])
+
+    with pytest.raises(ValueError, match="The target_label has to be a 1-dimensional array."):
+        attack.generate(x=x_train_mnist, y=None, target_label=np.asarray([[1, 2, 3], [4, 5, 6]]))
+
+
 def test_classifier_type_check_fail():
     backend_test_classifier_type_check_fail(DPatch, [BaseEstimator, LossGradientsMixin, ObjectDetectorMixin])
