 # -*- coding: utf-8 -*-
 # Copyright (c) Facebook, Inc. and its affiliates.
 
-from typing import List
 import torch
 from torchvision.ops import boxes as box_ops
-from torchvision.ops import nms  # BC-compat
+from torchvision.ops import nms  # noqa . for compatibility
 
 
 def batched_nms(
     boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
 ):
     """
-    Same as torchvision.ops.boxes.batched_nms, but safer.
+    Same as torchvision.ops.boxes.batched_nms, but with float().
     """
     assert boxes.shape[-1] == 4
-    # TODO may need better strategy.
-    # Investigate after having a fully-cuda NMS op.
-    if len(boxes) < 40000:
-        # fp16 does not have enough range for batched NMS
-        return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
-
-    result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
-    for id in torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist()):
-        mask = (idxs == id).nonzero().view(-1)
-        keep = nms(boxes[mask], scores[mask], iou_threshold)
-        result_mask[mask[keep]] = True
-    keep = result_mask.nonzero().view(-1)
-    keep = keep[scores[keep].argsort(descending=True)]
-    return keep
+    # Note: Torchvision already has a strategy (https://github.com/pytorch/vision/issues/1311)
+    # to decide whether to use coordinate trick or for loop to implement batched_nms. So we
+    # just call it directly.
+    # Fp16 does not have enough range for batched NMS, so adding float().
+    return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
 
 
 # Note: this function (nms_rotated) might be moved into
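
For reference, a minimal usage sketch of the batched_nms wrapper after this change (not part of the commit). The import path and the example boxes, scores, and class indices are illustrative assumptions; fp16 boxes are used only to show where the float() cast matters.

import torch
from detectron2.layers import batched_nms  # import path assumed; adjust if the wrapper lives elsewhere

# Hypothetical inputs: five boxes in (x1, y1, x2, y2) format, confidence scores,
# and a per-box class index so that suppression only happens within each class.
boxes = torch.tensor(
    [
        [0.0, 0.0, 10.0, 10.0],
        [1.0, 1.0, 11.0, 11.0],    # heavily overlaps the first box, same class
        [50.0, 50.0, 60.0, 60.0],
        [0.0, 0.0, 10.0, 10.0],    # same coordinates as the first box, but a different class
        [51.0, 51.0, 61.0, 61.0],  # heavily overlaps the third box, same class
    ],
    dtype=torch.float16,  # fp16 boxes are fine: the wrapper casts them to float32 internally
)
scores = torch.tensor([0.9, 0.8, 0.7, 0.95, 0.6])
idxs = torch.tensor([0, 0, 0, 1, 0])  # class ids; boxes of different classes never suppress each other

keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
print(keep)  # indices of the kept boxes, ordered by descending score, e.g. tensor([3, 0, 2])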