2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Dict , List , Tuple
26+ from typing import Dict , List , Tuple , Union
2727
2828import numpy as np
2929from tqdm .auto import tqdm
@@ -77,36 +77,39 @@ def __init__(
7777
7878 def poison ( # pylint: disable=W0221
7979 self ,
80- x : np .ndarray ,
80+ x : Union [ np .ndarray , List [ np . ndarray ]] ,
8181 y : List [Dict [str , np .ndarray ]],
8282 ** kwargs ,
83- ) -> Tuple [np .ndarray , List [Dict [str , np .ndarray ]]]:
83+ ) -> Tuple [Union [ np .ndarray , List [ np . ndarray ]] , List [Dict [str , np .ndarray ]]]:
8484 """
8585 Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
8686 for labels `y`.
8787
88- :param x: Sample images of shape `NCHW` or `NHWC`.
88+ :param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size .
8989 :param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
9090 of the dictionary are:
9191
9292 - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
9393 - labels [N]: the labels for each image.
94- - scores [N]: the scores or each prediction.
9594 :return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
9695 """
97- x_ndim = len (x .shape )
96+ if isinstance (x , np .ndarray ):
97+ x_ndim = len (x .shape )
98+ else :
99+ x_ndim = len (x [0 ].shape ) + 1
98100
99101 if x_ndim != 4 :
100102 raise ValueError ("Unrecognized input dimension. BadDet GMA can only be applied to image data." )
101103
102- if self . channels_first :
103- # NCHW --> NHWC
104- x = np . transpose (x , ( 0 , 2 , 3 , 1 ))
105-
106- x_poison = x . copy ()
107- y_poison : List [ Dict [ str , np . ndarray ]] = []
104+ # copy images
105+ x_poison : Union [ np . ndarray , List [ np . ndarray ]]
106+ if isinstance (x , np . ndarray ):
107+ x_poison = x . copy ()
108+ else :
109+ x_poison = [x_i . copy () for x_i in x ]
108110
109111 # copy labels
112+ y_poison : List [Dict [str , np .ndarray ]] = []
110113 for y_i in y :
111114 target_dict = {k : v .copy () for k , v in y_i .items ()}
112115 y_poison .append (target_dict )
@@ -120,18 +123,22 @@ def poison( # pylint: disable=W0221
120123 image = x_poison [i ]
121124 labels = y_poison [i ]["labels" ]
122125
126+ if self .channels_first :
127+ image = np .transpose (image , (1 , 2 , 0 ))
128+
123129 # insert backdoor into the image
124130 # add an additional dimension to create a batch of size 1
125131 poisoned_input , _ = self .backdoor .poison (image [np .newaxis ], labels )
126132 x_poison [i ] = poisoned_input [0 ]
127133
134+ # replace the original image with the poisoned image
135+ if self .channels_first :
136+ image = np .transpose (image , (2 , 0 , 1 ))
137+ x_poison [i ] = image
138+
128139 # change all labels to the target label
129140 y_poison [i ]["labels" ] = np .full (labels .shape , self .class_target )
130141
131- if self .channels_first :
132- # NHWC --> NCHW
133- x_poison = np .transpose (x_poison , (0 , 3 , 1 , 2 ))
134-
135142 return x_poison , y_poison
136143
137144 def _check_params (self ) -> None :
0 commit comments