Skip to content

Commit f7eaa37

Browse files
committed
extend bad det oga for arbitrary sizes
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 3e9defa commit f7eaa37

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

art/attacks/poisoning/bad_det/bad_det_oga.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple
26+
from typing import Dict, List, Tuple, Union
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -85,35 +85,39 @@ def __init__(
8585

8686
def poison( # pylint: disable=W0221
8787
self,
88-
x: np.ndarray,
88+
x: Union[np.ndarray, List[np.ndarray]],
8989
y: List[Dict[str, np.ndarray]],
9090
**kwargs,
91-
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
91+
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
9292
"""
9393
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
9494
for labels `y`.
9595
96-
:param x: Sample images of shape `NCHW` or `NHWC`.
96+
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
9797
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9898
of the dictionary are:
99+
99100
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
100101
- labels [N]: the labels for each image.
101-
- scores [N]: the scores or each prediction.
102102
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
103103
"""
104-
x_ndim = len(x.shape)
104+
if isinstance(x, np.ndarray):
105+
x_ndim = len(x.shape)
106+
else:
107+
x_ndim = len(x[0].shape) + 1
105108

106109
if x_ndim != 4:
107110
raise ValueError("Unrecognized input dimension. BadDet OGA can only be applied to image data.")
108111

109-
if self.channels_first:
110-
# NCHW --> NHWC
111-
x = np.transpose(x, (0, 2, 3, 1))
112-
113-
x_poison = x.copy()
114-
y_poison: List[Dict[str, np.ndarray]] = []
112+
# copy images
113+
x_poison: Union[np.ndarray, List[np.ndarray]]
114+
if isinstance(x, np.ndarray):
115+
x_poison = x.copy()
116+
else:
117+
x_poison = [x_i.copy() for x_i in x]
115118

116119
# copy labels
120+
y_poison: List[Dict[str, np.ndarray]] = []
117121
for y_i in y:
118122
target_dict = {k: v.copy() for k, v in y_i.items()}
119123
y_poison.append(target_dict)
@@ -123,14 +127,15 @@ def poison( # pylint: disable=W0221
123127
num_poison = int(self.percent_poison * len(all_indices))
124128
selected_indices = np.random.choice(all_indices, num_poison, replace=False)
125129

126-
_, height, width, _ = x_poison.shape
127-
128130
for i in tqdm(selected_indices, desc="BadDet OGA iteration", disable=not self.verbose):
129131
image = x_poison[i]
130-
131132
boxes = y_poison[i]["boxes"]
132133
labels = y_poison[i]["labels"]
133134

135+
if self.channels_first:
136+
image = np.transpose(image, (1, 2, 0))
137+
height, width, _ = image.shape
138+
134139
# generate the fake bounding box
135140
y_1 = np.random.randint(0, height - self.bbox_height)
136141
x_1 = np.random.randint(0, width - self.bbox_width)
@@ -145,6 +150,11 @@ def poison( # pylint: disable=W0221
145150
poisoned_input, _ = self.backdoor.poison(bounding_box[np.newaxis], labels)
146151
image[y_1:y_2, x_1:x_2, :] = poisoned_input[0]
147152

153+
# replace the original image with the poisoned image
154+
if self.channels_first:
155+
image = np.transpose(image, (2, 0, 1))
156+
x_poison[i] = image
157+
148158
# insert the fake bounding box and label
149159
y_poison[i]["boxes"] = np.concatenate((boxes, [[x_1, y_1, x_2, y_2]]))
150160
y_poison[i]["labels"] = np.concatenate((labels, [self.class_target]))
@@ -155,10 +165,6 @@ def poison( # pylint: disable=W0221
155165
mask[y_1:y_2, x_1:x_2, :] = 1
156166
y_poison[i]["masks"] = np.concatenate((y_poison[i]["masks"], [mask]))
157167

158-
if self.channels_first:
159-
# NHWC --> NCHW
160-
x_poison = np.transpose(x_poison, (0, 3, 1, 2))
161-
162168
return x_poison, y_poison
163169

164170
def _check_params(self) -> None:

0 commit comments

Comments
 (0)