2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Dict , List , Tuple
26+ from typing import Dict , List , Tuple , Union
2727
2828import numpy as np
2929from tqdm .auto import tqdm
@@ -85,35 +85,39 @@ def __init__(
8585
8686 def poison ( # pylint: disable=W0221
8787 self ,
88- x : np .ndarray ,
88+ x : Union [ np .ndarray , List [ np . ndarray ]] ,
8989 y : List [Dict [str , np .ndarray ]],
9090 ** kwargs ,
91- ) -> Tuple [np .ndarray , List [Dict [str , np .ndarray ]]]:
91+ ) -> Tuple [Union [ np .ndarray , List [ np . ndarray ]] , List [Dict [str , np .ndarray ]]]:
9292 """
9393 Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
9494 for labels `y`.
9595
96- :param x: Sample images of shape `NCHW` or `NHWC`.
96+ :param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size .
9797 :param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9898 of the dictionary are:
99+
99100 - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
100101 - labels [N]: the labels for each image.
101- - scores [N]: the scores or each prediction.
102102 :return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
103103 """
104- x_ndim = len (x .shape )
104+ if isinstance (x , np .ndarray ):
105+ x_ndim = len (x .shape )
106+ else :
107+ x_ndim = len (x [0 ].shape ) + 1
105108
106109 if x_ndim != 4 :
107110 raise ValueError ("Unrecognized input dimension. BadDet OGA can only be applied to image data." )
108111
109- if self . channels_first :
110- # NCHW --> NHWC
111- x = np . transpose (x , ( 0 , 2 , 3 , 1 ))
112-
113- x_poison = x . copy ()
114- y_poison : List [ Dict [ str , np . ndarray ]] = []
112+ # copy images
113+ x_poison : Union [ np . ndarray , List [ np . ndarray ]]
114+ if isinstance (x , np . ndarray ):
115+ x_poison = x . copy ()
116+ else :
117+ x_poison = [x_i . copy () for x_i in x ]
115118
116119 # copy labels
120+ y_poison : List [Dict [str , np .ndarray ]] = []
117121 for y_i in y :
118122 target_dict = {k : v .copy () for k , v in y_i .items ()}
119123 y_poison .append (target_dict )
@@ -123,14 +127,15 @@ def poison( # pylint: disable=W0221
123127 num_poison = int (self .percent_poison * len (all_indices ))
124128 selected_indices = np .random .choice (all_indices , num_poison , replace = False )
125129
126- _ , height , width , _ = x_poison .shape
127-
128130 for i in tqdm (selected_indices , desc = "BadDet OGA iteration" , disable = not self .verbose ):
129131 image = x_poison [i ]
130-
131132 boxes = y_poison [i ]["boxes" ]
132133 labels = y_poison [i ]["labels" ]
133134
135+ if self .channels_first :
136+ image = np .transpose (image , (1 , 2 , 0 ))
137+ height , width , _ = image .shape
138+
134139 # generate the fake bounding box
135140 y_1 = np .random .randint (0 , height - self .bbox_height )
136141 x_1 = np .random .randint (0 , width - self .bbox_width )
@@ -145,6 +150,11 @@ def poison( # pylint: disable=W0221
145150 poisoned_input , _ = self .backdoor .poison (bounding_box [np .newaxis ], labels )
146151 image [y_1 :y_2 , x_1 :x_2 , :] = poisoned_input [0 ]
147152
153+ # replace the original image with the poisoned image
154+ if self .channels_first :
155+ image = np .transpose (image , (2 , 0 , 1 ))
156+ x_poison [i ] = image
157+
148158 # insert the fake bounding box and label
149159 y_poison [i ]["boxes" ] = np .concatenate ((boxes , [[x_1 , y_1 , x_2 , y_2 ]]))
150160 y_poison [i ]["labels" ] = np .concatenate ((labels , [self .class_target ]))
@@ -155,10 +165,6 @@ def poison( # pylint: disable=W0221
155165 mask [y_1 :y_2 , x_1 :x_2 , :] = 1
156166 y_poison [i ]["masks" ] = np .concatenate ((y_poison [i ]["masks" ], [mask ]))
157167
158- if self .channels_first :
159- # NHWC --> NCHW
160- x_poison = np .transpose (x_poison , (0 , 3 , 1 , 2 ))
161-
162168 return x_poison , y_poison
163169
164170 def _check_params (self ) -> None :
0 commit comments