Skip to content

Commit e403bc6

Browse files
committed
Updating pytorch yolo wrapper to align with y_preprocessed (target) output format such that loss via ultralytics can be computed.
Signed-off-by: Kieran Fraser <[email protected]>
1 parent 78a8238 commit e403bc6

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

art/estimators/object_detection/pytorch_yolo_loss_wrapper.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,14 @@ def __init__(self, model, name):
4242
raise ImportError("The 'ultralytics' package is required for YOLO v8+ models but not installed.") from e
4343

4444
def forward(self, x, targets=None):
45+
"""Transforms the target to dict expected by model.loss"""
4546
if self.training:
46-
boxes = []
47-
labels = []
48-
indices = []
49-
for i, item in enumerate(targets):
50-
boxes.append(item["boxes"])
51-
labels.append(item["labels"])
52-
indices = indices + ([i] * len(item["labels"]))
53-
items = {
54-
"boxes": torch.cat(boxes) / x.shape[2],
55-
"labels": torch.cat(labels).type(torch.float32),
56-
"batch_idx": torch.tensor(indices),
57-
}
58-
items["bboxes"] = items.pop("boxes")
59-
items["cls"] = items.pop("labels")
47+
if targets is None:
48+
raise ValueError("Targets should not be None when training.")
49+
items = {}
50+
items["batch_idx"] = targets[:, 0]
51+
items["bboxes"] = targets[:, 2:6]
52+
items["cls"] = targets[:, 1]
6053
items["img"] = x
6154
loss, loss_components = self.model.loss(items)
6255
loss_components_dict = {"loss_total": loss.sum()}

0 commit comments

Comments
 (0)