Skip to content

Commit 3e9defa

Browse files
committed
extend bad det gma for arbitrary sizes
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 76ae22a commit 3e9defa

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

art/attacks/attack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,10 @@ def __init__(self):
339339
@abc.abstractmethod
340340
def poison(
341341
self,
342-
x: np.ndarray,
342+
x: Union[np.ndarray, List[np.ndarray]],
343343
y: List[Dict[str, np.ndarray]],
344344
**kwargs,
345-
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
345+
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
346346
"""
347347
Generate poisoning examples and return them as an array. This method should be overridden by all concrete
348348
poisoning attack implementations.

art/attacks/poisoning/bad_det/bad_det_gma.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple
26+
from typing import Dict, List, Tuple, Union
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -77,36 +77,39 @@ def __init__(
7777

7878
def poison( # pylint: disable=W0221
7979
self,
80-
x: np.ndarray,
80+
x: Union[np.ndarray, List[np.ndarray]],
8181
y: List[Dict[str, np.ndarray]],
8282
**kwargs,
83-
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
83+
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
8484
"""
8585
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
8686
for labels `y`.
8787
88-
:param x: Sample images of shape `NCHW` or `NHWC`.
88+
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
8989
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9090
of the dictionary are:
9191
9292
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9393
- labels [N]: the labels for each image.
94-
- scores [N]: the scores or each prediction.
9594
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
9695
"""
97-
x_ndim = len(x.shape)
96+
if isinstance(x, np.ndarray):
97+
x_ndim = len(x.shape)
98+
else:
99+
x_ndim = len(x[0].shape) + 1
98100

99101
if x_ndim != 4:
100102
raise ValueError("Unrecognized input dimension. BadDet GMA can only be applied to image data.")
101103

102-
if self.channels_first:
103-
# NCHW --> NHWC
104-
x = np.transpose(x, (0, 2, 3, 1))
105-
106-
x_poison = x.copy()
107-
y_poison: List[Dict[str, np.ndarray]] = []
104+
# copy images
105+
x_poison: Union[np.ndarray, List[np.ndarray]]
106+
if isinstance(x, np.ndarray):
107+
x_poison = x.copy()
108+
else:
109+
x_poison = [x_i.copy() for x_i in x]
108110

109111
# copy labels
112+
y_poison: List[Dict[str, np.ndarray]] = []
110113
for y_i in y:
111114
target_dict = {k: v.copy() for k, v in y_i.items()}
112115
y_poison.append(target_dict)
@@ -120,18 +123,22 @@ def poison( # pylint: disable=W0221
120123
image = x_poison[i]
121124
labels = y_poison[i]["labels"]
122125

126+
if self.channels_first:
127+
image = np.transpose(image, (1, 2, 0))
128+
123129
# insert backdoor into the image
124130
# add an additional dimension to create a batch of size 1
125131
poisoned_input, _ = self.backdoor.poison(image[np.newaxis], labels)
126132
x_poison[i] = poisoned_input[0]
127133

134+
# replace the original image with the poisoned image
135+
if self.channels_first:
136+
image = np.transpose(image, (2, 0, 1))
137+
x_poison[i] = image
138+
128139
# change all labels to the target label
129140
y_poison[i]["labels"] = np.full(labels.shape, self.class_target)
130141

131-
if self.channels_first:
132-
# NHWC --> NCHW
133-
x_poison = np.transpose(x_poison, (0, 3, 1, 2))
134-
135142
return x_poison, y_poison
136143

137144
def _check_params(self) -> None:

0 commit comments

Comments
 (0)