diff --git a/projects/IDOL/idol/idol.py b/projects/IDOL/idol/idol.py index 008dab3..709ab0c 100644 --- a/projects/IDOL/idol/idol.py +++ b/projects/IDOL/idol/idol.py @@ -365,6 +365,8 @@ def inference(self, outputs, tracker, ori_size, image_sizes): det_bboxes = torch.cat([output_boxes[indices],box_score.unsqueeze(1)],dim=1) det_labels = torch.argmax(logits.sigmoid()[indices],dim=1) track_feats = output_embed[indices] + if isinstance(indices, torch.Tensor): + indices = indices.cpu() det_masks = output_mask[indices] bboxes, labels, ids, indices = tracker.match( bboxes=det_bboxes, @@ -542,4 +544,4 @@ def preprocess_coco_image(self, batched_inputs): - \ No newline at end of file +