Skip to content

Commit 389a146

Browse files
authored
Fix batch index for cascade roi head (#2078)
1 parent 5fd0e89 commit 389a146

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def cascade_roi_head__predict_bbox(self,
6969
new_rois = get_box_tensor(new_rois)
7070
rois = new_rois.reshape(-1, new_rois.shape[-1])
7171
# Add dummy batch index
72-
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
72+
rois = torch.cat([batch_index.flatten(0, 1), rois], dim=-1)
7373
cls_scores = sum(ms_scores) / float(len(ms_scores))
7474
bbox_preds = bbox_pred.reshape(batch_size, num_proposals_per_img, -1)
7575
rois = rois.reshape(batch_size, num_proposals_per_img, -1)

0 commit comments

Comments
 (0)