52
52
from sklearn .cluster import DBSCAN
53
53
from tqdm .auto import tqdm
54
54
55
- from art .utils import intersection_over_area , non_maximum_suppression
55
+ from art .utils import intersection_over_area
56
56
57
57
logger = logging .getLogger (__name__ )
58
58
@@ -94,68 +94,16 @@ def __init__(
94
94
self .epsilon = epsilon
95
95
self .verbose = verbose
96
96
97
- @property
98
97
@abc .abstractmethod
99
- def channels_first (self ) -> bool :
98
+ def _image_dimensions (self ) -> Tuple [ int , int ] :
100
99
"""
101
- :return: Boolean to indicate index of the color channels in the sample `x`.
102
- """
103
- pass
104
-
105
- @property
106
- @abc .abstractmethod
107
- def input_shape (self ) -> Tuple [int , ...]:
108
- """
109
- :return: Shape of one input sample.
110
- """
111
- pass
100
+ Get the height and width of a sample input image.
112
101
113
- @abc .abstractmethod
114
- def _predict_classifier (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
115
- """
116
- Perform prediction for a batch of inputs.
117
-
118
- :param x: Samples of shape NCHW or NHWC.
119
- :param batch_size: Batch size.
120
- :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
121
- are as follows:
122
-
123
- - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
124
- - labels [N]: the labels for each image
125
- - scores [N]: the scores or each prediction.
102
+ :return: Tuple containing the height and width of a sample input image.
126
103
"""
127
104
raise NotImplementedError
128
105
129
- def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
130
- """
131
- Perform prediction for a batch of inputs.
132
-
133
- :param x: Samples of shape NCHW or NHWC.
134
- :param batch_size: Batch size.
135
- :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
136
- are as follows:
137
-
138
- - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
139
- - labels [N]: the labels for each image
140
- - scores [N]: the scores or each prediction.
141
- """
142
- predictions = []
143
-
144
- for x_i in tqdm (x , desc = "ObjectSeeker" , disable = not self .verbose ):
145
- base_preds , masked_preds = self ._masked_predictions (x_i , batch_size = batch_size , ** kwargs )
146
- pruned_preds = self ._prune_boxes (masked_preds , base_preds )
147
- unionized_preds = self ._unionize_clusters (pruned_preds )
148
-
149
- preds = {
150
- "boxes" : np .concatenate ([base_preds ["boxes" ], unionized_preds ["boxes" ]]),
151
- "labels" : np .concatenate ([base_preds ["labels" ], unionized_preds ["labels" ]]),
152
- "scores" : np .concatenate ([base_preds ["scores" ], unionized_preds ["scores" ]]),
153
- }
154
-
155
- predictions .append (preds )
156
-
157
- return predictions
158
-
106
+ @abc .abstractmethod
159
107
def _masked_predictions (
160
108
self , x_i : np .ndarray , batch_size : int = 128 , ** kwargs
161
109
) -> Tuple [Dict [str , np .ndarray ], Dict [str , np .ndarray ]]:
@@ -167,70 +115,7 @@ def _masked_predictions(
167
115
:batch_size: Batch size.
168
116
:return: Predictions for the base unmasked image and merged predictions for the masked image.
169
117
"""
170
- x_mask = np .repeat (x_i [np .newaxis ], self .num_lines * 4 + 1 , axis = 0 )
171
-
172
- if self .channels_first :
173
- height = self .input_shape [1 ]
174
- width = self .input_shape [2 ]
175
- else :
176
- height = self .input_shape [0 ]
177
- width = self .input_shape [1 ]
178
- x_mask = np .transpose (x_mask , (0 , 3 , 1 , 2 ))
179
-
180
- idx = 1
181
-
182
- # Left masks
183
- for k in range (1 , self .num_lines + 1 ):
184
- boundary = int (width / (self .num_lines + 1 ) * k )
185
- x_mask [idx , :, :, :boundary ] = 0
186
- idx += 1
187
-
188
- # Right masks
189
- for k in range (1 , self .num_lines + 1 ):
190
- boundary = width - int (width / (self .num_lines + 1 ) * k )
191
- x_mask [idx , :, :, boundary :] = 0
192
- idx += 1
193
-
194
- # Top masks
195
- for k in range (1 , self .num_lines + 1 ):
196
- boundary = int (height / (self .num_lines + 1 ) * k )
197
- x_mask [idx , :, :boundary , :] = 0
198
- idx += 1
199
-
200
- # Bottom masks
201
- for k in range (1 , self .num_lines + 1 ):
202
- boundary = height - int (height / (self .num_lines + 1 ) * k )
203
- x_mask [idx , :, boundary :, :] = 0
204
- idx += 1
205
-
206
- if not self .channels_first :
207
- x_mask = np .transpose (x_mask , (0 , 2 , 3 , 1 ))
208
-
209
- predictions = self ._predict_classifier (x = x_mask , batch_size = batch_size , ** kwargs )
210
- filtered_predictions = [
211
- non_maximum_suppression (
212
- pred , iou_threshold = self .iou_threshold , confidence_threshold = self .confidence_threshold
213
- )
214
- for pred in predictions
215
- ]
216
-
217
- # Extract base predictions
218
- base_predictions = filtered_predictions [0 ]
219
-
220
- # Extract and merge masked predictions
221
- boxes = np .concatenate ([pred ["boxes" ] for pred in filtered_predictions [1 :]])
222
- labels = np .concatenate ([pred ["labels" ] for pred in filtered_predictions [1 :]])
223
- scores = np .concatenate ([pred ["scores" ] for pred in filtered_predictions [1 :]])
224
- merged_predictions = {
225
- "boxes" : boxes ,
226
- "labels" : labels ,
227
- "scores" : scores ,
228
- }
229
- masked_predictions = non_maximum_suppression (
230
- merged_predictions , iou_threshold = self .iou_threshold , confidence_threshold = self .confidence_threshold
231
- )
232
-
233
- return base_predictions , masked_predictions
118
+ raise NotImplementedError
234
119
235
120
def _prune_boxes (
236
121
self , masked_preds : Dict [str , np .ndarray ], base_preds : Dict [str , np .ndarray ]
@@ -332,6 +217,36 @@ def _unionize_clusters(self, masked_preds: Dict[str, np.ndarray]) -> Dict[str, n
332
217
}
333
218
return unionized_predictions
334
219
220
+ def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
221
+ """
222
+ Perform prediction for a batch of inputs.
223
+
224
+ :param x: Samples of shape NCHW or NHWC.
225
+ :param batch_size: Batch size.
226
+ :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
227
+ are as follows:
228
+
229
+ - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
230
+ - labels [N]: the labels for each image
231
+ - scores [N]: the scores or each prediction.
232
+ """
233
+ predictions = []
234
+
235
+ for x_i in tqdm (x , desc = "ObjectSeeker" , disable = not self .verbose ):
236
+ base_preds , masked_preds = self ._masked_predictions (x_i , batch_size = batch_size , ** kwargs )
237
+ pruned_preds = self ._prune_boxes (masked_preds , base_preds )
238
+ unionized_preds = self ._unionize_clusters (pruned_preds )
239
+
240
+ preds = {
241
+ "boxes" : np .concatenate ([base_preds ["boxes" ], unionized_preds ["boxes" ]]),
242
+ "labels" : np .concatenate ([base_preds ["labels" ], unionized_preds ["labels" ]]),
243
+ "scores" : np .concatenate ([base_preds ["scores" ], unionized_preds ["scores" ]]),
244
+ }
245
+
246
+ predictions .append (preds )
247
+
248
+ return predictions
249
+
335
250
def certify (
336
251
self ,
337
252
x : np .ndarray ,
@@ -348,10 +263,7 @@ def certify(
348
263
:return: A list containing an array of bools for each bounding box per image indicating if the bounding
349
264
box is certified against the given patch.
350
265
"""
351
- if self .channels_first :
352
- _ , height , width = self .input_shape
353
- else :
354
- height , width , _ = self .input_shape
266
+ height , width = self ._image_dimensions ()
355
267
356
268
patch_size = np .sqrt (height * width * patch_size )
357
269
height_offset = offset * height
0 commit comments