Skip to content

Commit 1608d9b

Browse files
authored
Ensure target class indices are of type long in loss calculations (#4143)
* Ensure target class indices are of type long in loss calculations * update changelog
1 parent 5d6f8d3 commit 1608d9b

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ All notable changes to this project will be documented in this file.
132132
(<https://github.com/openvinotoolkit/training_extensions/pull/4131>)
133133
- Fix tensor type compatibility in dynamic soft label assigner and RTMDet head
134134
(<https://github.com/openvinotoolkit/training_extensions/pull/4140>)
135+
- Fix DETR target class indices are of type long in loss calculations
136+
(<https://github.com/openvinotoolkit/training_extensions/pull/4143>)
135137

136138
## \[v2.1.0\]
137139

src/otx/algo/detection/losses/rtdetr_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def loss_labels_vfl(
7777
src_logits = outputs["pred_logits"]
7878
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
7979
target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
80-
target_classes[idx] = target_classes_o
80+
target_classes[idx] = target_classes_o.long()
8181
target = nn.functional.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
8282

8383
target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)

src/otx/algo/detection/utils/matchers/hungarian_matcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def forward(
7171
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
7272

7373
# Also concat the target labels and boxes
74-
tgt_ids = torch.cat([v["labels"] for v in targets])
74+
tgt_ids = torch.cat([v["labels"] for v in targets]).long()
7575
tgt_bbox = torch.cat([v["boxes"] for v in targets])
7676

7777
# Compute the classification cost. Contrary to the loss, we don't use the NLL,

0 commit comments

Comments
 (0)