Skip to content

Commit cf035f6

Browse files
authored
Fix tensor type compatibility in dynamic soft label assigner and RTMDet head (#4140)
* Fix tensor type compatibility in dynamic soft label assigner and RTMDet head * Update CHANGELOG
1 parent ec610a9 commit cf035f6

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ All notable changes to this project will be documented in this file.
126126
(<https://github.com/openvinotoolkit/training_extensions/pull/4107>)
127127
- Fix empty annotation in tiling
128128
(<https://github.com/openvinotoolkit/training_extensions/pull/4124>)
129+
- Fix tensor type compatibility in dynamic soft label assigner and RTMDet head
130+
(<https://github.com/openvinotoolkit/training_extensions/pull/4140>)
129131

130132
## \[v2.1.0\]
131133

src/otx/algo/common/utils/assigners/dynamic_soft_label_assigner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def assign(
196196
assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
197197
assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
198198
max_overlaps = assigned_gt_inds.new_full((num_bboxes,), -INF, dtype=torch.float32)
199-
max_overlaps[valid_mask] = matched_pred_ious
199+
max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps)
200200
return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
201201

202202
def dynamic_k_matching(

src/otx/algo/detection/heads/rtmdet_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def _get_targets_single( # type: ignore[override]
574574
if len(pos_inds) > 0:
575575
# point-based
576576
pos_bbox_targets = sampling_result.pos_gt_bboxes
577-
bbox_targets[pos_inds, :] = pos_bbox_targets
577+
bbox_targets[pos_inds, :] = pos_bbox_targets.to(bbox_targets)
578578

579579
labels[pos_inds] = sampling_result.pos_gt_labels
580580
if self.train_cfg["pos_weight"] <= 0:

0 commit comments

Comments
 (0)