Commit e5483a9

fix pytorch yolo input format
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 0cc12f4 commit e5483a9

File tree

2 files changed: +172 -153 lines

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 69 additions & 50 deletions
@@ -43,74 +43,77 @@ def translate_predictions_xcycwh_to_x1y1x2y2(
     y_pred_xcycwh: "torch.Tensor", input_height: int, input_width: int
 ) -> List[Dict[str, "torch.Tensor"]]:
     """
-    Convert object detection predictions from xcycwh to x1y1x2y2 format.
+    Convert object detection predictions from xcycwh (YOLO) to x1y1x2y2 (torchvision).
 
-    :param y_pred_xcycwh: Labels in format xcycwh.
-    :return: Labels in format x1y1x2y2.
+    :param y_pred_xcycwh: Object detection labels in format xcycwh (YOLO).
+    :param input_height: Height of the input images in pixels.
+    :param input_width: Width of the input images in pixels.
+    :return: Object detection labels in format x1y1x2y2 (torchvision).
     """
     import torch
 
     y_pred_x1y1x2y2 = []
-
-    for i in range(y_pred_xcycwh.shape[0]):
+    device = y_pred_xcycwh.device
+
+    for y_pred in y_pred_xcycwh:
+        boxes = torch.vstack(
+            [
+                torch.maximum((y_pred[:, 0] - y_pred[:, 2] / 2), torch.tensor(0).to(device)),
+                torch.maximum((y_pred[:, 1] - y_pred[:, 3] / 2), torch.tensor(0).to(device)),
+                torch.minimum((y_pred[:, 0] + y_pred[:, 2] / 2), torch.tensor(input_height).to(device)),
+                torch.minimum((y_pred[:, 1] + y_pred[:, 3] / 2), torch.tensor(input_width).to(device)),
+            ]
+        ).permute((1, 0))
+        labels = torch.argmax(y_pred[:, 5:], dim=1, keepdim=False)
+        scores = y_pred[:, 4]
 
         y_i = {
-            "boxes": torch.permute(
-                torch.vstack(
-                    [
-                        torch.maximum(
-                            (y_pred_xcycwh[i, :, 0] - y_pred_xcycwh[i, :, 2] / 2),
-                            torch.tensor(0).to(y_pred_xcycwh.device),
-                        ),
-                        torch.maximum(
-                            (y_pred_xcycwh[i, :, 1] - y_pred_xcycwh[i, :, 3] / 2),
-                            torch.tensor(0).to(y_pred_xcycwh.device),
-                        ),
-                        torch.minimum(
-                            (y_pred_xcycwh[i, :, 0] + y_pred_xcycwh[i, :, 2] / 2),
-                            torch.tensor(input_height).to(y_pred_xcycwh.device),
-                        ),
-                        torch.minimum(
-                            (y_pred_xcycwh[i, :, 1] + y_pred_xcycwh[i, :, 3] / 2),
-                            torch.tensor(input_width).to(y_pred_xcycwh.device),
-                        ),
-                    ]
-                ),
-                (1, 0),
-            ),
-            "labels": torch.argmax(y_pred_xcycwh[i, :, 5:], dim=1, keepdim=False),
-            "scores": y_pred_xcycwh[i, :, 4],
+            "boxes": boxes,
+            "labels": labels,
+            "scores": scores,
         }
 
         y_pred_x1y1x2y2.append(y_i)
 
     return y_pred_x1y1x2y2
 
 
-def translate_labels_art_to_yolov3(labels_art: List[Dict[str, "torch.Tensor"]]):
+def translate_labels_x1y1x2y2_to_xcycwh(
+    labels_x1y1x2y2: List[Dict[str, "torch.Tensor"]], input_height: int, input_width: int
+) -> "torch.Tensor":
     """
-    Translate labels from ART to YOLO v3 and v5.
+    Translate object detection labels from x1y1x2y2 (torchvision) to xcycwh (YOLO).
 
-    :param labels_art: Object detection labels in format ART (torchvision).
-    :return: Object detection labels in format YOLO v3 and v5.
+    :param labels_x1y1x2y2: Object detection labels in format x1y1x2y2 (torchvision).
+    :param input_height: Height of the input images in pixels.
+    :param input_width: Width of the input images in pixels.
+    :return: Object detection labels in format xcycwh (YOLO).
     """
     import torch
 
-    yolo_targets_list = []
+    labels_xcycwh_list = []
+
+    for i, label_dict in enumerate(labels_x1y1x2y2):
+        # create 2D tensor to encode labels and bounding boxes
+        labels = torch.zeros(len(label_dict["boxes"]), 6)
+        labels[:, 0] = i
+        labels[:, 1] = label_dict["labels"]
+        labels[:, 2:6] = label_dict["boxes"]
+
+        # normalize bounding boxes to [0, 1]
+        labels[:, 2:6:2] /= input_width
+        labels[:, 3:6:2] /= input_height
 
-    for i_dict, label_dict in enumerate(labels_art):
-        num_detectors = label_dict["boxes"].size()[0]
-        targets = torch.zeros(num_detectors, 6)
-        targets[:, 0] = i_dict
-        targets[:, 1] = label_dict["labels"]
-        targets[:, 2:6] = label_dict["boxes"]
-        targets[:, 4] = targets[:, 4] - targets[:, 2]
-        targets[:, 5] = targets[:, 5] - targets[:, 3]
-        yolo_targets_list.append(targets)
+        # convert from x1y1x2y2 to xcycwh
+        labels[:, 4] = labels[:, 4] - labels[:, 2]
+        labels[:, 5] = labels[:, 5] - labels[:, 3]
+        labels[:, 2] = labels[:, 2] + labels[:, 4] / 2
+        labels[:, 3] = labels[:, 3] + labels[:, 5] / 2
+        labels_xcycwh_list.append(labels)
 
-    yolo_targets = torch.vstack(yolo_targets_list)
+    labels_xcycwh = torch.vstack(labels_xcycwh_list)
 
-    return yolo_targets
+    return labels_xcycwh
 
 
 class PyTorchYolo(ObjectDetectorMixin, PyTorchEstimator):
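For context, a minimal sketch (not part of the commit) of how the two translation helpers compose, assuming they are importable from art.estimators.object_detection.pytorch_yolo as the diff suggests, and assuming the usual COCO-style YOLO v3/v5 output layout of 85 columns per detection (xc, yc, w, h, objectness, then 80 class scores):

import torch

from art.estimators.object_detection.pytorch_yolo import (
    translate_labels_x1y1x2y2_to_xcycwh,
    translate_predictions_xcycwh_to_x1y1x2y2,
)

input_height = input_width = 416  # assumed square YOLO input

# one image, two candidate detections, 85 columns (assumed layout, see above)
y_pred = torch.zeros(1, 2, 85)
y_pred[0, :, :4] = torch.tensor([[208.0, 208.0, 100.0, 50.0], [100.0, 120.0, 40.0, 60.0]])
y_pred[0, :, 4] = torch.tensor([0.9, 0.8])  # objectness scores
y_pred[0, 0, 5 + 3] = 1.0  # first box is class 3
y_pred[0, 1, 5 + 7] = 1.0  # second box is class 7

# YOLO output -> torchvision-style dicts with absolute x1y1x2y2 boxes
preds = translate_predictions_xcycwh_to_x1y1x2y2(y_pred, input_height, input_width)
print(preds[0]["boxes"])  # [[158., 183., 258., 233.], [80., 90., 120., 150.]]

# torchvision-style dicts -> (N, 6) target tensor with normalized xcycwh rows
targets = translate_labels_x1y1x2y2_to_xcycwh(preds, input_height, input_width)
print(targets)  # columns: image index, class label, xc, yc, w, h, boxes in [0, 1]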
@@ -274,8 +277,6 @@ def _get_losses(
         import torch
 
         self._model.train()
-        self.set_batchnorm(train=False)
-        self.set_dropout(train=False)
 
         # Apply preprocessing
         if self.all_framework_preprocessing:
@@ -344,7 +345,16 @@ def _get_losses(
         else:
             raise NotImplementedError("Combination of inputs and preprocessing not supported.")
 
-        labels_t = translate_labels_art_to_yolov3(labels_art=y_preprocessed)
+        if self.channels_first:
+            height = self.input_shape[1]
+            width = self.input_shape[2]
+        else:
+            height = self.input_shape[0]
+            width = self.input_shape[1]
+
+        labels_t = translate_labels_x1y1x2y2_to_xcycwh(
+            labels_x1y1x2y2=y_preprocessed, input_height=height, input_width=width
+        )
 
         loss_components = self._model(inputs_t, labels_t)
 
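The lookup added above is needed because translate_labels_x1y1x2y2_to_xcycwh now normalizes boxes by image size, so _get_losses must know the input dimensions in pixels. A small sketch of that branch, assuming an estimator built with input_shape=(3, 416, 416) and channels_first=True:

input_shape = (3, 416, 416)  # (C, H, W) when channels_first=True
channels_first = True

if channels_first:
    height, width = input_shape[1], input_shape[2]
else:  # (H, W, C) layout
    height, width = input_shape[0], input_shape[1]

assert (height, width) == (416, 416)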
@@ -528,6 +538,13 @@ def fit(  # pylint: disable=W0221
         else:
             x_preprocessed = torch.stack([transform(x_i / norm_factor).to(self.device) for x_i in x_preprocessed])
 
+        if self.channels_first:
+            height = self.input_shape[1]
+            width = self.input_shape[2]
+        else:
+            height = self.input_shape[0]
+            width = self.input_shape[1]
+
         # Convert labels into tensors, if needed
         if isinstance(y_preprocessed[0]["boxes"], np.ndarray):
             y_preprocessed_tensor = []
@@ -563,7 +580,9 @@ def fit(  # pylint: disable=W0221
             self._optimizer.zero_grad()
 
             # Form the loss function
-            labels_t = translate_labels_art_to_yolov3(labels_art=o_batch)
+            labels_t = translate_labels_x1y1x2y2_to_xcycwh(
+                labels_x1y1x2y2=o_batch, input_height=height, input_width=width
+            )
             loss_components = self._model(i_batch, labels_t)
             if isinstance(loss_components, dict):
                 loss = sum(loss_components.values())

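A quick sanity check (hypothetical, not part of the commit) that the targets handed to the model in fit() are now normalized: every xcycwh entry should land in [0, 1] when the boxes lie inside a 416x416 input.

import torch

from art.estimators.object_detection.pytorch_yolo import translate_labels_x1y1x2y2_to_xcycwh

o_batch = [
    {"boxes": torch.tensor([[80.0, 90.0, 120.0, 150.0]]), "labels": torch.tensor([7])},
    {"boxes": torch.tensor([[0.0, 0.0, 416.0, 416.0]]), "labels": torch.tensor([0])},
]

labels_t = translate_labels_x1y1x2y2_to_xcycwh(o_batch, input_height=416, input_width=416)

assert labels_t.shape == (2, 6)  # one row per box: image index, class, xc, yc, w, h
assert torch.all(labels_t[:, 2:] >= 0) and torch.all(labels_t[:, 2:] <= 1)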