Skip to content

Commit 25e283a

Browse files
MrParoskfacebook-github-bot
authored andcommitted
Add support of torchscript for rotated_nms
Summary: Related to issue: #2525 I had to change the types from float to double in the files for it to compile. Pull Request resolved: #2550 Reviewed By: theschnitz Differential Revision: D26169817 Pulled By: ppwwyyxx fbshipit-source-id: 0a43467b4eb99f11d95f219b777add6ce01fdc19
1 parent 1c557d3 commit 25e283a

File tree

6 files changed

+45
-10
lines changed

6 files changed

+45
-10
lines changed

detectron2/layers/csrc/nms_rotated/nms_rotated.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ namespace detectron2 {
77
at::Tensor nms_rotated_cpu(
88
const at::Tensor& dets,
99
const at::Tensor& scores,
10-
const float iou_threshold);
10+
const double iou_threshold);
1111

1212
#if defined(WITH_CUDA) || defined(WITH_HIP)
1313
at::Tensor nms_rotated_cuda(
1414
const at::Tensor& dets,
1515
const at::Tensor& scores,
16-
const float iou_threshold);
16+
const double iou_threshold);
1717
#endif
1818

1919
// Interface for Python
@@ -22,7 +22,7 @@ at::Tensor nms_rotated_cuda(
2222
inline at::Tensor nms_rotated(
2323
const at::Tensor& dets,
2424
const at::Tensor& scores,
25-
const float iou_threshold) {
25+
const double iou_threshold) {
2626
assert(dets.device().is_cuda() == scores.device().is_cuda());
2727
if (dets.device().is_cuda()) {
2828
#if defined(WITH_CUDA) || defined(WITH_HIP)

detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ template <typename scalar_t>
88
at::Tensor nms_rotated_cpu_kernel(
99
const at::Tensor& dets,
1010
const at::Tensor& scores,
11-
const float iou_threshold) {
11+
const double iou_threshold) {
1212
// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,
1313
// however, the code in this function is much shorter because
1414
// we delegate the IoU computation for rotated boxes to
@@ -63,7 +63,7 @@ at::Tensor nms_rotated_cpu(
6363
// input must be contiguous
6464
const at::Tensor& dets,
6565
const at::Tensor& scores,
66-
const float iou_threshold) {
66+
const double iou_threshold) {
6767
auto result = at::empty({0}, dets.options());
6868

6969
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] {

detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ int const threadsPerBlock = sizeof(unsigned long long) * 8;
2020
template <typename T>
2121
__global__ void nms_rotated_cuda_kernel(
2222
const int n_boxes,
23-
const float iou_threshold,
23+
const double iou_threshold,
2424
const T* dev_boxes,
2525
unsigned long long* dev_mask) {
2626
// nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel
@@ -81,7 +81,7 @@ at::Tensor nms_rotated_cuda(
8181
// input must be contiguous
8282
const at::Tensor& dets,
8383
const at::Tensor& scores,
84-
float iou_threshold) {
84+
double iou_threshold) {
8585
// using scalar_t = float;
8686
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
8787
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");

detectron2/layers/csrc/vision.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
115115
pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
116116
.def(pybind11::init<>());
117117
}
118+
119+
#ifdef TORCH_LIBRARY
120+
TORCH_LIBRARY(detectron2, m) {
121+
m.def("nms_rotated", &nms_rotated);
122+
}
123+
#endif
118124
} // namespace detectron2

detectron2/layers/nms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
from torchvision.ops import boxes as box_ops
77
from torchvision.ops import nms # BC-compat
88

9+
from detectron2.utils.env import TORCH_VERSION
10+
11+
if TORCH_VERSION < (1, 7):
12+
from detectron2 import _C
13+
14+
nms_rotated_func = _C.nms_rotated
15+
else:
16+
nms_rotated_func = torch.ops.detectron2.nms_rotated
17+
918

1019
def batched_nms(
1120
boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
@@ -93,9 +102,7 @@ def nms_rotated(boxes, scores, iou_threshold):
93102
keep (Tensor): int64 tensor with the indices of the elements that have been kept
94103
by Rotated NMS, sorted in decreasing order of scores
95104
"""
96-
from detectron2 import _C
97-
98-
return _C.nms_rotated(boxes, scores, iou_threshold)
105+
return nms_rotated_func(boxes, scores, iou_threshold)
99106

100107

101108
# Note: this function (batched_nms_rotated) might be moved into

tests/layers/test_nms_rotated.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from __future__ import absolute_import, division, print_function, unicode_literals
33
import numpy as np
44
import unittest
5+
from copy import deepcopy
56
import torch
67
from torchvision import ops
78

89
from detectron2.layers import batched_nms, batched_nms_rotated, nms_rotated
10+
from detectron2.utils.env import TORCH_VERSION
911
from detectron2.utils.testing import random_boxes
1012

1113

@@ -149,5 +151,25 @@ def test_nms_rotated_180_degrees_cpu(self):
149151
self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou))
150152

151153

154+
class TestScriptable(unittest.TestCase):
155+
def setUp(self):
156+
class TestingModule(torch.nn.Module):
157+
def forward(self, boxes, scores, threshold):
158+
return nms_rotated(boxes, scores, threshold)
159+
160+
self.module = TestingModule()
161+
162+
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
163+
def test_scriptable_cpu(self):
164+
m = deepcopy(self.module).cpu()
165+
_ = torch.jit.script(m)
166+
167+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
168+
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
169+
def test_scriptable_cuda(self):
170+
m = deepcopy(self.module).cuda()
171+
_ = torch.jit.script(m)
172+
173+
152174
if __name__ == "__main__":
153175
unittest.main()

0 commit comments

Comments
 (0)