5252from sklearn .cluster import DBSCAN
5353from tqdm .auto import tqdm
5454
55- from art .utils import intersection_over_area , non_maximum_suppression
55+ from art .utils import intersection_over_area
5656
5757logger = logging .getLogger (__name__ )
5858
@@ -94,68 +94,16 @@ def __init__(
9494 self .epsilon = epsilon
9595 self .verbose = verbose
9696
97- @property
9897 @abc .abstractmethod
99- def channels_first (self ) -> bool :
98+ def _image_dimensions (self ) -> Tuple [ int , int ] :
10099 """
101- :return: Boolean to indicate index of the color channels in the sample `x`.
102- """
103- pass
104-
105- @property
106- @abc .abstractmethod
107- def input_shape (self ) -> Tuple [int , ...]:
108- """
109- :return: Shape of one input sample.
110- """
111- pass
100+ Get the height and width of a sample input image.
112101
113- @abc .abstractmethod
114- def _predict_classifier (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
115- """
116- Perform prediction for a batch of inputs.
117-
118- :param x: Samples of shape NCHW or NHWC.
119- :param batch_size: Batch size.
120- :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
121- are as follows:
122-
123- - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
124- - labels [N]: the labels for each image
125- - scores [N]: the scores or each prediction.
102+ :return: Tuple containing the height and width of a sample input image.
126103 """
127104 raise NotImplementedError
128105
129- def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
130- """
131- Perform prediction for a batch of inputs.
132-
133- :param x: Samples of shape NCHW or NHWC.
134- :param batch_size: Batch size.
135- :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
136- are as follows:
137-
138- - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
139- - labels [N]: the labels for each image
140- - scores [N]: the scores or each prediction.
141- """
142- predictions = []
143-
144- for x_i in tqdm (x , desc = "ObjectSeeker" , disable = not self .verbose ):
145- base_preds , masked_preds = self ._masked_predictions (x_i , batch_size = batch_size , ** kwargs )
146- pruned_preds = self ._prune_boxes (masked_preds , base_preds )
147- unionized_preds = self ._unionize_clusters (pruned_preds )
148-
149- preds = {
150- "boxes" : np .concatenate ([base_preds ["boxes" ], unionized_preds ["boxes" ]]),
151- "labels" : np .concatenate ([base_preds ["labels" ], unionized_preds ["labels" ]]),
152- "scores" : np .concatenate ([base_preds ["scores" ], unionized_preds ["scores" ]]),
153- }
154-
155- predictions .append (preds )
156-
157- return predictions
158-
106+ @abc .abstractmethod
159107 def _masked_predictions (
160108 self , x_i : np .ndarray , batch_size : int = 128 , ** kwargs
161109 ) -> Tuple [Dict [str , np .ndarray ], Dict [str , np .ndarray ]]:
@@ -167,70 +115,7 @@ def _masked_predictions(
167115 :batch_size: Batch size.
168116 :return: Predictions for the base unmasked image and merged predictions for the masked image.
169117 """
170- x_mask = np .repeat (x_i [np .newaxis ], self .num_lines * 4 + 1 , axis = 0 )
171-
172- if self .channels_first :
173- height = self .input_shape [1 ]
174- width = self .input_shape [2 ]
175- else :
176- height = self .input_shape [0 ]
177- width = self .input_shape [1 ]
178- x_mask = np .transpose (x_mask , (0 , 3 , 1 , 2 ))
179-
180- idx = 1
181-
182- # Left masks
183- for k in range (1 , self .num_lines + 1 ):
184- boundary = int (width / (self .num_lines + 1 ) * k )
185- x_mask [idx , :, :, :boundary ] = 0
186- idx += 1
187-
188- # Right masks
189- for k in range (1 , self .num_lines + 1 ):
190- boundary = width - int (width / (self .num_lines + 1 ) * k )
191- x_mask [idx , :, :, boundary :] = 0
192- idx += 1
193-
194- # Top masks
195- for k in range (1 , self .num_lines + 1 ):
196- boundary = int (height / (self .num_lines + 1 ) * k )
197- x_mask [idx , :, :boundary , :] = 0
198- idx += 1
199-
200- # Bottom masks
201- for k in range (1 , self .num_lines + 1 ):
202- boundary = height - int (height / (self .num_lines + 1 ) * k )
203- x_mask [idx , :, boundary :, :] = 0
204- idx += 1
205-
206- if not self .channels_first :
207- x_mask = np .transpose (x_mask , (0 , 2 , 3 , 1 ))
208-
209- predictions = self ._predict_classifier (x = x_mask , batch_size = batch_size , ** kwargs )
210- filtered_predictions = [
211- non_maximum_suppression (
212- pred , iou_threshold = self .iou_threshold , confidence_threshold = self .confidence_threshold
213- )
214- for pred in predictions
215- ]
216-
217- # Extract base predictions
218- base_predictions = filtered_predictions [0 ]
219-
220- # Extract and merge masked predictions
221- boxes = np .concatenate ([pred ["boxes" ] for pred in filtered_predictions [1 :]])
222- labels = np .concatenate ([pred ["labels" ] for pred in filtered_predictions [1 :]])
223- scores = np .concatenate ([pred ["scores" ] for pred in filtered_predictions [1 :]])
224- merged_predictions = {
225- "boxes" : boxes ,
226- "labels" : labels ,
227- "scores" : scores ,
228- }
229- masked_predictions = non_maximum_suppression (
230- merged_predictions , iou_threshold = self .iou_threshold , confidence_threshold = self .confidence_threshold
231- )
232-
233- return base_predictions , masked_predictions
118+ raise NotImplementedError
234119
235120 def _prune_boxes (
236121 self , masked_preds : Dict [str , np .ndarray ], base_preds : Dict [str , np .ndarray ]
@@ -332,6 +217,36 @@ def _unionize_clusters(self, masked_preds: Dict[str, np.ndarray]) -> Dict[str, n
332217 }
333218 return unionized_predictions
334219
220+ def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> List [Dict [str , np .ndarray ]]:
221+ """
222+ Perform prediction for a batch of inputs.
223+
224+ :param x: Samples of shape NCHW or NHWC.
225+ :param batch_size: Batch size.
226+ :return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
227+ are as follows:
228+
229+ - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
230+ - labels [N]: the labels for each image
231+ - scores [N]: the scores or each prediction.
232+ """
233+ predictions = []
234+
235+ for x_i in tqdm (x , desc = "ObjectSeeker" , disable = not self .verbose ):
236+ base_preds , masked_preds = self ._masked_predictions (x_i , batch_size = batch_size , ** kwargs )
237+ pruned_preds = self ._prune_boxes (masked_preds , base_preds )
238+ unionized_preds = self ._unionize_clusters (pruned_preds )
239+
240+ preds = {
241+ "boxes" : np .concatenate ([base_preds ["boxes" ], unionized_preds ["boxes" ]]),
242+ "labels" : np .concatenate ([base_preds ["labels" ], unionized_preds ["labels" ]]),
243+ "scores" : np .concatenate ([base_preds ["scores" ], unionized_preds ["scores" ]]),
244+ }
245+
246+ predictions .append (preds )
247+
248+ return predictions
249+
335250 def certify (
336251 self ,
337252 x : np .ndarray ,
@@ -348,10 +263,7 @@ def certify(
348263 :return: A list containing an array of bools for each bounding box per image indicating if the bounding
349264 box is certified against the given patch.
350265 """
351- if self .channels_first :
352- _ , height , width = self .input_shape
353- else :
354- height , width , _ = self .input_shape
266+ height , width = self ._image_dimensions ()
355267
356268 patch_size = np .sqrt (height * width * patch_size )
357269 height_offset = offset * height
0 commit comments