2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Dict , List , Tuple
26+ from typing import Dict , List , Tuple , Union
2727
2828import numpy as np
2929from tqdm .auto import tqdm
@@ -77,36 +77,39 @@ def __init__(
7777
7878 def poison ( # pylint: disable=W0221
7979 self ,
80- x : np .ndarray ,
80+ x : Union [ np .ndarray , List [ np . ndarray ]] ,
8181 y : List [Dict [str , np .ndarray ]],
8282 ** kwargs ,
83- ) -> Tuple [np .ndarray , List [Dict [str , np .ndarray ]]]:
83+ ) -> Tuple [Union [ np .ndarray , List [ np . ndarray ]] , List [Dict [str , np .ndarray ]]]:
8484 """
8585 Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
8686 for labels `y`.
8787
88- :param x: Sample images of shape `NCHW` or `NHWC`.
88+ :param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size .
8989 :param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9090 of the dictionary are:
9191
9292 - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9393 - labels [N]: the labels for each image.
94- - scores [N]: the scores or each prediction.
9594 :return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
9695 """
97- x_ndim = len (x .shape )
96+ if isinstance (x , np .ndarray ):
97+ x_ndim = len (x .shape )
98+ else :
99+ x_ndim = len (x [0 ].shape ) + 1
98100
99101 if x_ndim != 4 :
100102 raise ValueError ("Unrecognized input dimension. BadDet ODA can only be applied to image data." )
101103
102- if self . channels_first :
103- # NCHW --> NHWC
104- x = np . transpose (x , ( 0 , 2 , 3 , 1 ))
105-
106- x_poison = x . copy ()
107- y_poison : List [ Dict [ str , np . ndarray ]] = []
104+ # copy images
105+ x_poison : Union [ np . ndarray , List [ np . ndarray ]]
106+ if isinstance (x , np . ndarray ):
107+ x_poison = x . copy ()
108+ else :
109+ x_poison = [x_i . copy () for x_i in x ]
108110
109111 # copy labels and find indices of the source class
112+ y_poison : List [Dict [str , np .ndarray ]] = []
110113 source_indices = []
111114 for i , y_i in enumerate (y ):
112115 target_dict = {k : v .copy () for k , v in y_i .items ()}
@@ -121,10 +124,12 @@ def poison( # pylint: disable=W0221
121124
122125 for i in tqdm (selected_indices , desc = "BadDet ODA iteration" , disable = not self .verbose ):
123126 image = x_poison [i ]
124-
125127 boxes = y_poison [i ]["boxes" ]
126128 labels = y_poison [i ]["labels" ]
127129
130+ if self .channels_first :
131+ image = np .transpose (image , (1 , 2 , 0 ))
132+
128133 keep_indices = []
129134
130135 for j , (box , label ) in enumerate (zip (boxes , labels )):
@@ -140,13 +145,14 @@ def poison( # pylint: disable=W0221
140145 else :
141146 keep_indices .append (j )
142147
148+ # replace the original image with the poisoned image
149+ if self .channels_first :
150+ image = np .transpose (image , (2 , 0 , 1 ))
151+ x_poison [i ] = image
152+
143153 # remove labels for poisoned bounding boxes
144154 y_poison [i ] = {k : v [keep_indices ] for k , v in y_poison [i ].items ()}
145155
146- if self .channels_first :
147- # NHWC --> NCHW
148- x_poison = np .transpose (x_poison , (0 , 3 , 1 , 2 ))
149-
150156 return x_poison , y_poison
151157
152158 def _check_params (self ) -> None :
0 commit comments