Skip to content

Commit a4f403c

Browse files
committed
extend bad det oda for arbitrary sizes
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 4ee410b commit a4f403c

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

art/attacks/poisoning/bad_det/bad_det_oda.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple
26+
from typing import Dict, List, Tuple, Union
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -77,36 +77,39 @@ def __init__(
7777

7878
def poison( # pylint: disable=W0221
7979
self,
80-
x: np.ndarray,
80+
x: Union[np.ndarray, List[np.ndarray]],
8181
y: List[Dict[str, np.ndarray]],
8282
**kwargs,
83-
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
83+
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
8484
"""
8585
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
8686
for labels `y`.
8787
88-
:param x: Sample images of shape `NCHW` or `NHWC`.
88+
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
8989
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9090
of the dictionary are:
9191
9292
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9393
- labels [N]: the labels for each image.
94-
- scores [N]: the scores or each prediction.
9594
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
9695
"""
97-
x_ndim = len(x.shape)
96+
if isinstance(x, np.ndarray):
97+
x_ndim = len(x.shape)
98+
else:
99+
x_ndim = len(x[0].shape) + 1
98100

99101
if x_ndim != 4:
100102
raise ValueError("Unrecognized input dimension. BadDet ODA can only be applied to image data.")
101103

102-
if self.channels_first:
103-
# NCHW --> NHWC
104-
x = np.transpose(x, (0, 2, 3, 1))
105-
106-
x_poison = x.copy()
107-
y_poison: List[Dict[str, np.ndarray]] = []
104+
# copy images
105+
x_poison: Union[np.ndarray, List[np.ndarray]]
106+
if isinstance(x, np.ndarray):
107+
x_poison = x.copy()
108+
else:
109+
x_poison = [x_i.copy() for x_i in x]
108110

109111
# copy labels and find indices of the source class
112+
y_poison: List[Dict[str, np.ndarray]] = []
110113
source_indices = []
111114
for i, y_i in enumerate(y):
112115
target_dict = {k: v.copy() for k, v in y_i.items()}
@@ -121,10 +124,12 @@ def poison( # pylint: disable=W0221
121124

122125
for i in tqdm(selected_indices, desc="BadDet ODA iteration", disable=not self.verbose):
123126
image = x_poison[i]
124-
125127
boxes = y_poison[i]["boxes"]
126128
labels = y_poison[i]["labels"]
127129

130+
if self.channels_first:
131+
image = np.transpose(image, (1, 2, 0))
132+
128133
keep_indices = []
129134

130135
for j, (box, label) in enumerate(zip(boxes, labels)):
@@ -140,13 +145,14 @@ def poison( # pylint: disable=W0221
140145
else:
141146
keep_indices.append(j)
142147

148+
# replace the original image with the poisoned image
149+
if self.channels_first:
150+
image = np.transpose(image, (2, 0, 1))
151+
x_poison[i] = image
152+
143153
# remove labels for poisoned bounding boxes
144154
y_poison[i] = {k: v[keep_indices] for k, v in y_poison[i].items()}
145155

146-
if self.channels_first:
147-
# NHWC --> NCHW
148-
x_poison = np.transpose(x_poison, (0, 3, 1, 2))
149-
150156
return x_poison, y_poison
151157

152158
def _check_params(self) -> None:

0 commit comments

Comments
 (0)