Skip to content

Commit 4ee410b

Browse files
committed
extend bad det rma for arbitrary sizes
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent f7eaa37 commit 4ee410b

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

art/attacks/poisoning/bad_det/bad_det_rma.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple, Optional
26+
from typing import Dict, List, Tuple, Union, Optional
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -82,36 +82,39 @@ def __init__(
8282

8383
def poison( # pylint: disable=W0221
8484
self,
85-
x: np.ndarray,
85+
x: Union[np.ndarray, List[np.ndarray]],
8686
y: List[Dict[str, np.ndarray]],
8787
**kwargs,
88-
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
88+
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
8989
"""
9090
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
9191
for labels `y`.
9292
93-
:param x: Sample images of shape `NCHW` or `NHWC`.
93+
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
9494
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9595
of the dictionary are:
9696
9797
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9898
- labels [N]: the labels for each image.
99-
- scores [N]: the scores or each prediction.
10099
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
101100
"""
102-
x_ndim = len(x.shape)
101+
if isinstance(x, np.ndarray):
102+
x_ndim = len(x.shape)
103+
else:
104+
x_ndim = len(x[0].shape) + 1
103105

104106
if x_ndim != 4:
105107
raise ValueError("Unrecognized input dimension. BadDet RMA can only be applied to image data.")
106108

107-
if self.channels_first:
108-
# NCHW --> NHWC
109-
x = np.transpose(x, (0, 2, 3, 1))
110-
111-
x_poison = x.copy()
112-
y_poison: List[Dict[str, np.ndarray]] = []
109+
# copy images
110+
x_poison: Union[np.ndarray, List[np.ndarray]]
111+
if isinstance(x, np.ndarray):
112+
x_poison = x.copy()
113+
else:
114+
x_poison = [x_i.copy() for x_i in x]
113115

114116
# copy labels and find indices of the source class
117+
y_poison: List[Dict[str, np.ndarray]] = []
115118
source_indices = []
116119
for i, y_i in enumerate(y):
117120
target_dict = {k: v.copy() for k, v in y_i.items()}
@@ -126,10 +129,12 @@ def poison( # pylint: disable=W0221
126129

127130
for i in tqdm(selected_indices, desc="BadDet RMA iteration", disable=not self.verbose):
128131
image = x_poison[i]
129-
130132
boxes = y_poison[i]["boxes"]
131133
labels = y_poison[i]["labels"]
132134

135+
if self.channels_first:
136+
image = np.transpose(image, (1, 2, 0))
137+
133138
for j, (box, label) in enumerate(zip(boxes, labels)):
134139
if self.class_source is None or label == self.class_source:
135140
# extract the bounding box from the image
@@ -144,9 +149,10 @@ def poison( # pylint: disable=W0221
144149
# change the source label to the target label
145150
labels[j] = self.class_target
146151

147-
if self.channels_first:
148-
# NHWC --> NCHW
149-
x_poison = np.transpose(x_poison, (0, 3, 1, 2))
152+
# replace the original image with the poisoned image
153+
if self.channels_first:
154+
image = np.transpose(image, (2, 0, 1))
155+
x_poison[i] = image
150156

151157
return x_poison, y_poison
152158

0 commit comments

Comments
 (0)