Skip to content

Commit 3597228

Browse files
authored
Merge pull request #2321 from f4str/pytorch-yolo-rebase
Subclass `PyTorchYolo` and `PyTorchDetectionTransformer` off `PyTorchObjectDetector`
2 parents 80fd393 + 9e347e6 commit 3597228

File tree

11 files changed

+994
-1882
lines changed

11 files changed

+994
-1882
lines changed

art/estimators/certification/object_seeker/object_seeker.py

Lines changed: 37 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@
5252
from sklearn.cluster import DBSCAN
5353
from tqdm.auto import tqdm
5454

55-
from art.utils import intersection_over_area, non_maximum_suppression
55+
from art.utils import intersection_over_area
5656

5757
logger = logging.getLogger(__name__)
5858

@@ -94,68 +94,16 @@ def __init__(
9494
self.epsilon = epsilon
9595
self.verbose = verbose
9696

97-
@property
9897
@abc.abstractmethod
99-
def channels_first(self) -> bool:
98+
def _image_dimensions(self) -> Tuple[int, int]:
10099
"""
101-
:return: Boolean to indicate index of the color channels in the sample `x`.
102-
"""
103-
pass
104-
105-
@property
106-
@abc.abstractmethod
107-
def input_shape(self) -> Tuple[int, ...]:
108-
"""
109-
:return: Shape of one input sample.
110-
"""
111-
pass
100+
Get the height and width of a sample input image.
112101
113-
@abc.abstractmethod
114-
def _predict_classifier(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[str, np.ndarray]]:
115-
"""
116-
Perform prediction for a batch of inputs.
117-
118-
:param x: Samples of shape NCHW or NHWC.
119-
:param batch_size: Batch size.
120-
:return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
121-
are as follows:
122-
123-
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
124-
- labels [N]: the labels for each image
125-
- scores [N]: the scores of each prediction.
102+
:return: Tuple containing the height and width of a sample input image.
126103
"""
127104
raise NotImplementedError
128105

129-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[str, np.ndarray]]:
130-
"""
131-
Perform prediction for a batch of inputs.
132-
133-
:param x: Samples of shape NCHW or NHWC.
134-
:param batch_size: Batch size.
135-
:return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
136-
are as follows:
137-
138-
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
139-
- labels [N]: the labels for each image
140-
- scores [N]: the scores of each prediction.
141-
"""
142-
predictions = []
143-
144-
for x_i in tqdm(x, desc="ObjectSeeker", disable=not self.verbose):
145-
base_preds, masked_preds = self._masked_predictions(x_i, batch_size=batch_size, **kwargs)
146-
pruned_preds = self._prune_boxes(masked_preds, base_preds)
147-
unionized_preds = self._unionize_clusters(pruned_preds)
148-
149-
preds = {
150-
"boxes": np.concatenate([base_preds["boxes"], unionized_preds["boxes"]]),
151-
"labels": np.concatenate([base_preds["labels"], unionized_preds["labels"]]),
152-
"scores": np.concatenate([base_preds["scores"], unionized_preds["scores"]]),
153-
}
154-
155-
predictions.append(preds)
156-
157-
return predictions
158-
106+
@abc.abstractmethod
159107
def _masked_predictions(
160108
self, x_i: np.ndarray, batch_size: int = 128, **kwargs
161109
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
@@ -167,70 +115,7 @@ def _masked_predictions(
167115
:param batch_size: Batch size.
168116
:return: Predictions for the base unmasked image and merged predictions for the masked image.
169117
"""
170-
x_mask = np.repeat(x_i[np.newaxis], self.num_lines * 4 + 1, axis=0)
171-
172-
if self.channels_first:
173-
height = self.input_shape[1]
174-
width = self.input_shape[2]
175-
else:
176-
height = self.input_shape[0]
177-
width = self.input_shape[1]
178-
x_mask = np.transpose(x_mask, (0, 3, 1, 2))
179-
180-
idx = 1
181-
182-
# Left masks
183-
for k in range(1, self.num_lines + 1):
184-
boundary = int(width / (self.num_lines + 1) * k)
185-
x_mask[idx, :, :, :boundary] = 0
186-
idx += 1
187-
188-
# Right masks
189-
for k in range(1, self.num_lines + 1):
190-
boundary = width - int(width / (self.num_lines + 1) * k)
191-
x_mask[idx, :, :, boundary:] = 0
192-
idx += 1
193-
194-
# Top masks
195-
for k in range(1, self.num_lines + 1):
196-
boundary = int(height / (self.num_lines + 1) * k)
197-
x_mask[idx, :, :boundary, :] = 0
198-
idx += 1
199-
200-
# Bottom masks
201-
for k in range(1, self.num_lines + 1):
202-
boundary = height - int(height / (self.num_lines + 1) * k)
203-
x_mask[idx, :, boundary:, :] = 0
204-
idx += 1
205-
206-
if not self.channels_first:
207-
x_mask = np.transpose(x_mask, (0, 2, 3, 1))
208-
209-
predictions = self._predict_classifier(x=x_mask, batch_size=batch_size, **kwargs)
210-
filtered_predictions = [
211-
non_maximum_suppression(
212-
pred, iou_threshold=self.iou_threshold, confidence_threshold=self.confidence_threshold
213-
)
214-
for pred in predictions
215-
]
216-
217-
# Extract base predictions
218-
base_predictions = filtered_predictions[0]
219-
220-
# Extract and merge masked predictions
221-
boxes = np.concatenate([pred["boxes"] for pred in filtered_predictions[1:]])
222-
labels = np.concatenate([pred["labels"] for pred in filtered_predictions[1:]])
223-
scores = np.concatenate([pred["scores"] for pred in filtered_predictions[1:]])
224-
merged_predictions = {
225-
"boxes": boxes,
226-
"labels": labels,
227-
"scores": scores,
228-
}
229-
masked_predictions = non_maximum_suppression(
230-
merged_predictions, iou_threshold=self.iou_threshold, confidence_threshold=self.confidence_threshold
231-
)
232-
233-
return base_predictions, masked_predictions
118+
raise NotImplementedError
234119

235120
def _prune_boxes(
236121
self, masked_preds: Dict[str, np.ndarray], base_preds: Dict[str, np.ndarray]
@@ -332,6 +217,36 @@ def _unionize_clusters(self, masked_preds: Dict[str, np.ndarray]) -> Dict[str, n
332217
}
333218
return unionized_predictions
334219

220+
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[str, np.ndarray]]:
221+
"""
222+
Perform prediction for a batch of inputs.
223+
224+
:param x: Samples of shape NCHW or NHWC.
225+
:param batch_size: Batch size.
226+
:return: Predictions of format `List[Dict[str, np.ndarray]]`, one for each input image. The fields of the Dict
227+
are as follows:
228+
229+
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
230+
- labels [N]: the labels for each image
231+
- scores [N]: the scores of each prediction.
232+
"""
233+
predictions = []
234+
235+
for x_i in tqdm(x, desc="ObjectSeeker", disable=not self.verbose):
236+
base_preds, masked_preds = self._masked_predictions(x_i, batch_size=batch_size, **kwargs)
237+
pruned_preds = self._prune_boxes(masked_preds, base_preds)
238+
unionized_preds = self._unionize_clusters(pruned_preds)
239+
240+
preds = {
241+
"boxes": np.concatenate([base_preds["boxes"], unionized_preds["boxes"]]),
242+
"labels": np.concatenate([base_preds["labels"], unionized_preds["labels"]]),
243+
"scores": np.concatenate([base_preds["scores"], unionized_preds["scores"]]),
244+
}
245+
246+
predictions.append(preds)
247+
248+
return predictions
249+
335250
def certify(
336251
self,
337252
x: np.ndarray,
@@ -348,10 +263,7 @@ def certify(
348263
:return: A list containing an array of bools for each bounding box per image indicating if the bounding
349264
box is certified against the given patch.
350265
"""
351-
if self.channels_first:
352-
_, height, width = self.input_shape
353-
else:
354-
height, width, _ = self.input_shape
266+
height, width = self._image_dimensions()
355267

356268
patch_size = np.sqrt(height * width * patch_size)
357269
height_offset = offset * height

0 commit comments

Comments
 (0)