2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Dict , List , Tuple , Optional
26+ from typing import Dict , List , Tuple , Union , Optional
2727
2828import numpy as np
2929from tqdm .auto import tqdm
@@ -82,36 +82,39 @@ def __init__(
8282
8383 def poison ( # pylint: disable=W0221
8484 self ,
85- x : np .ndarray ,
85+ x : Union [ np .ndarray , List [ np . ndarray ]] ,
8686 y : List [Dict [str , np .ndarray ]],
8787 ** kwargs ,
88- ) -> Tuple [np .ndarray , List [Dict [str , np .ndarray ]]]:
88+ ) -> Tuple [Union [ np .ndarray , List [ np . ndarray ]] , List [Dict [str , np .ndarray ]]]:
8989 """
9090 Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
9191 for labels `y`.
9292
93- :param x: Sample images of shape `NCHW` or `NHWC`.
93+ :param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size .
9494 :param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9595 of the dictionary are:
9696
9797 - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9898 - labels [N]: the labels for each image.
99- - scores [N]: the scores or each prediction.
10099 :return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
101100 """
102- x_ndim = len (x .shape )
101+ if isinstance (x , np .ndarray ):
102+ x_ndim = len (x .shape )
103+ else :
104+ x_ndim = len (x [0 ].shape ) + 1
103105
104106 if x_ndim != 4 :
105107 raise ValueError ("Unrecognized input dimension. BadDet RMA can only be applied to image data." )
106108
107- if self . channels_first :
108- # NCHW --> NHWC
109- x = np . transpose (x , ( 0 , 2 , 3 , 1 ))
110-
111- x_poison = x . copy ()
112- y_poison : List [ Dict [ str , np . ndarray ]] = []
109+ # copy images
110+ x_poison : Union [ np . ndarray , List [ np . ndarray ]]
111+ if isinstance (x , np . ndarray ):
112+ x_poison = x . copy ()
113+ else :
114+ x_poison = [x_i . copy () for x_i in x ]
113115
114116 # copy labels and find indices of the source class
117+ y_poison : List [Dict [str , np .ndarray ]] = []
115118 source_indices = []
116119 for i , y_i in enumerate (y ):
117120 target_dict = {k : v .copy () for k , v in y_i .items ()}
@@ -126,10 +129,12 @@ def poison( # pylint: disable=W0221
126129
127130 for i in tqdm (selected_indices , desc = "BadDet RMA iteration" , disable = not self .verbose ):
128131 image = x_poison [i ]
129-
130132 boxes = y_poison [i ]["boxes" ]
131133 labels = y_poison [i ]["labels" ]
132134
135+ if self .channels_first :
136+ image = np .transpose (image , (1 , 2 , 0 ))
137+
133138 for j , (box , label ) in enumerate (zip (boxes , labels )):
134139 if self .class_source is None or label == self .class_source :
135140 # extract the bounding box from the image
@@ -144,9 +149,10 @@ def poison( # pylint: disable=W0221
144149 # change the source label to the target label
145150 labels [j ] = self .class_target
146151
147- if self .channels_first :
148- # NHWC --> NCHW
149- x_poison = np .transpose (x_poison , (0 , 3 , 1 , 2 ))
152+ # replace the original image with the poisoned image
153+ if self .channels_first :
154+ image = np .transpose (image , (2 , 0 , 1 ))
155+ x_poison [i ] = image
150156
151157 return x_poison , y_poison
152158
0 commit comments