Commit 8745a31

ppwwyyxx authored and facebook-github-bot committed
call batched_nms from torchvision directly
Summary: torchvision already has the thresholding logic. The threshold in torchvision is not good, which will be fixed in pytorch/vision#4990.

Reviewed By: zhanghang1989

Differential Revision: D32661585

fbshipit-source-id: 971b68a914c2bba29dff37298c38e31b02735c1d
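For context, here is a hypothetical usage sketch of the wrapper touched by this commit (it assumes batched_nms is re-exported from detectron2.layers, and the box/score values are illustrative only). After this change the wrapper simply casts boxes to float32 and delegates to torchvision, so callers may pass fp16 boxes without worrying about the limited fp16 range:

import torch
from detectron2.layers import batched_nms  # wrapper updated in this commit

# Illustrative inputs: 5 boxes in (x1, y1, x2, y2) format, deliberately fp16.
boxes = torch.tensor(
    [[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60], [0, 0, 10, 10], [51, 51, 61, 61]],
    dtype=torch.float16,
)
scores = torch.tensor([0.9, 0.8, 0.7, 0.95, 0.6])
idxs = torch.tensor([0, 0, 0, 1, 1])  # per-box category ids

# The wrapper casts boxes to float32 before calling torchvision, so fp16 inputs are safe.
keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
print(keep)  # indices of the kept boxes, sorted by descending score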
1 parent f14e631 commit 8745a31


detectron2/layers/nms.py

Lines changed: 7 additions & 17 deletions
@@ -1,33 +1,23 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Facebook, Inc. and its affiliates.
 
-from typing import List
 import torch
 from torchvision.ops import boxes as box_ops
-from torchvision.ops import nms  # BC-compat
+from torchvision.ops import nms  # noqa . for compatibility
 
 
 def batched_nms(
     boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
 ):
     """
-    Same as torchvision.ops.boxes.batched_nms, but safer.
+    Same as torchvision.ops.boxes.batched_nms, but with float().
     """
     assert boxes.shape[-1] == 4
-    # TODO may need better strategy.
-    # Investigate after having a fully-cuda NMS op.
-    if len(boxes) < 40000:
-        # fp16 does not have enough range for batched NMS
-        return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
-
-    result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
-    for id in torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist()):
-        mask = (idxs == id).nonzero().view(-1)
-        keep = nms(boxes[mask], scores[mask], iou_threshold)
-        result_mask[mask[keep]] = True
-    keep = result_mask.nonzero().view(-1)
-    keep = keep[scores[keep].argsort(descending=True)]
-    return keep
+    # Note: Torchvision already has a strategy (https://github.com/pytorch/vision/issues/1311)
+    # to decide whether to use coordinate trick or for loop to implement batched_nms. So we
+    # just call it directly.
+    # Fp16 does not have enough range for batched NMS, so adding float().
+    return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
 
 
 # Note: this function (nms_rotated) might be moved into
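As an aside, the "coordinate trick" mentioned in the new comment is the standard way to run class-aware NMS with a single class-agnostic NMS call: offset each box by a per-class constant larger than any real coordinate, so boxes of different classes land in disjoint regions and can never suppress each other. Below is a minimal sketch of that idea (the function name is hypothetical; torchvision's actual implementation also handles the large-box-count fallback referenced above):

import torch
from torchvision.ops import nms


def batched_nms_via_offsets(boxes, scores, idxs, iou_threshold):
    # Hypothetical illustration of the coordinate trick: shift every box by
    # (class id) * (largest coordinate + 1) so boxes from different classes
    # occupy disjoint regions, then run one plain NMS over everything.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    return nms(boxes_for_nms, scores, iou_threshold)

For very large numbers of boxes a per-class for loop can be faster than the offset approach, and this commit delegates that choice to torchvision instead of keeping detectron2's own 40000-box heuristic.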
