diff --git a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py index 7d70113ae..24b42ba9d 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -45,7 +45,7 @@ def forward(self, outputs, orig_target_sizes): boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) else: - scores = F.softmax(logits)[:, :, :-1] + scores = F.softmax(logits, dim=-1) scores, labels = scores.max(dim=-1) boxes = bbox_pred if scores.shape[1] > self.num_top_queries: