
Commit 0a98929

address review comments
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 1142f31 commit 0a98929

File tree

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 10 additions & 13 deletions
@@ -40,7 +40,7 @@


 def translate_predictions_xcycwh_to_x1y1x2y2(
-    y_pred_xcycwh: "torch.Tensor", input_height: int, input_width: int
+    y_pred_xcycwh: "torch.Tensor", height: int, width: int
 ) -> List[Dict[str, "torch.Tensor"]]:
     """
     Convert object detection predictions from xcycwh (YOLO) to x1y1x2y2 (torchvision).
@@ -60,8 +60,8 @@ def translate_predictions_xcycwh_to_x1y1x2y2(
             [
                 torch.maximum((y_pred[:, 0] - y_pred[:, 2] / 2), torch.tensor(0, device=device)),
                 torch.maximum((y_pred[:, 1] - y_pred[:, 3] / 2), torch.tensor(0, device=device)),
-                torch.minimum((y_pred[:, 0] + y_pred[:, 2] / 2), torch.tensor(input_height, device=device)),
-                torch.minimum((y_pred[:, 1] + y_pred[:, 3] / 2), torch.tensor(input_width, device=device)),
+                torch.minimum((y_pred[:, 0] + y_pred[:, 2] / 2), torch.tensor(height, device=device)),
+                torch.minimum((y_pred[:, 1] + y_pred[:, 3] / 2), torch.tensor(width, device=device)),
             ]
         ).permute((1, 0))
         labels = torch.argmax(y_pred[:, 5:], dim=1, keepdim=False)
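The arithmetic in this hunk is unchanged by the rename; only the clamp bounds are now called height and width. As a quick sanity check of the xcycwh to x1y1x2y2 conversion these lines perform (the sample box values below are invented for illustration, not from the commit):

import torch

# One (cx, cy, w, h) box centred at (50, 40) with width 20 and height 10:
y_pred = torch.tensor([[50.0, 40.0, 20.0, 10.0]])
x1 = y_pred[:, 0] - y_pred[:, 2] / 2  # tensor([40.])
y1 = y_pred[:, 1] - y_pred[:, 3] / 2  # tensor([35.])
x2 = y_pred[:, 0] + y_pred[:, 2] / 2  # tensor([60.])
y2 = y_pred[:, 1] + y_pred[:, 3] / 2  # tensor([45.])
# The torch.maximum/torch.minimum calls in the hunk then clamp each
# coordinate to the image bounds so boxes never extend past the border.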
@@ -79,7 +79,7 @@ def translate_predictions_xcycwh_to_x1y1x2y2(


 def translate_labels_x1y1x2y2_to_xcycwh(
-    labels_x1y1x2y2: List[Dict[str, "torch.Tensor"]], input_height: int, input_width: int
+    labels_x1y1x2y2: List[Dict[str, "torch.Tensor"]], height: int, width: int
 ) -> "torch.Tensor":
     """
     Translate object detection labels from x1y1x2y2 (torchvision) to xcycwh (YOLO).
@@ -102,8 +102,8 @@ def translate_labels_x1y1x2y2_to_xcycwh(
         labels[:, 2:6] = label_dict["boxes"]

         # normalize bounding boxes to [0, 1]
-        labels[:, 2:6:2] /= input_width
-        labels[:, 3:6:2] /= input_height
+        labels[:, 2:6:2] /= width
+        labels[:, 3:6:2] /= height

         # convert from x1y1x2y2 to xcycwh
         labels[:, 4] -= labels[:, 2]
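Same rename on the label side. A worked example of the two changed lines, with an invented box and image size: x-coordinates (columns 2 and 4) are divided by the image width and y-coordinates (columns 3 and 5) by the height, ahead of the x1y1x2y2 to xcycwh arithmetic that begins on the last context line above.

import torch

# Box (x1, y1, x2, y2) = (40, 35, 60, 45) on a 100x100 image:
labels = torch.zeros(1, 6)
labels[:, 2:6] = torch.tensor([40.0, 35.0, 60.0, 45.0])
labels[:, 2:6:2] /= 100  # width  -> x1, x2 become 0.40, 0.60
labels[:, 3:6:2] /= 100  # height -> y1, y2 become 0.35, 0.45
# The subsequent xcycwh conversion yields (cx, cy, w, h) = (0.5, 0.4, 0.2, 0.1).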
@@ -387,9 +387,7 @@ def _get_losses(
         height = self.input_shape[0]
         width = self.input_shape[1]

-        labels_t = translate_labels_x1y1x2y2_to_xcycwh(
-            labels_x1y1x2y2=y_preprocessed, input_height=height, input_width=width
-        )
+        labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_preprocessed, height=height, width=width)

         loss_components = self._model(x_grad, labels_t)

@@ -491,7 +489,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
             predictions_xcycwh = self._model(i_batch)

             predictions_x1y1x2y2 = translate_predictions_xcycwh_to_x1y1x2y2(
-                y_pred_xcycwh=predictions_xcycwh, input_height=height, input_width=width
+                y_pred_xcycwh=predictions_xcycwh, height=height, width=width
             )

             for prediction_x1y1x2y2 in predictions_x1y1x2y2:
@@ -544,6 +542,7 @@ def fit(  # pylint: disable=W0221

         # Apply preprocessing and convert to tensors
         x_preprocessed, y_preprocessed_list = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
+
         # Cast to np.ndarray to use list indexing
         y_preprocessed = np.asarray(y_preprocessed_list)

@@ -575,9 +574,7 @@ def fit(  # pylint: disable=W0221
                 self._optimizer.zero_grad()

                 # Form the loss function
-                labels_t = translate_labels_x1y1x2y2_to_xcycwh(
-                    labels_x1y1x2y2=o_batch, input_height=height, input_width=width
-                )
+                labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=o_batch, height=height, width=width)
                 loss_components = self._model(i_batch, labels_t)
                 if isinstance(loss_components, dict):
                     loss = sum(loss_components.values())
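For reference, a minimal sketch of a call site after this commit, assuming the post-commit signature shown in the hunks above; the sample box, class id, and 416x416 input size are invented for illustration:

import torch
from art.estimators.object_detection.pytorch_yolo import translate_labels_x1y1x2y2_to_xcycwh

# One torchvision-style label dict (values made up for the example):
labels = [{"boxes": torch.tensor([[40.0, 35.0, 60.0, 45.0]]), "labels": torch.tensor([3])}]
# Keyword arguments are now height=/width= rather than input_height=/input_width=:
labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=labels, height=416, width=416)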
