|
2 | 2 | from __future__ import absolute_import, division, print_function, unicode_literals |
3 | 3 | import numpy as np |
4 | 4 | import unittest |
| 5 | +from copy import deepcopy |
5 | 6 | import torch |
6 | 7 | from torchvision import ops |
7 | 8 |
|
8 | 9 | from detectron2.layers import batched_nms, batched_nms_rotated, nms_rotated |
| 10 | +from detectron2.utils.env import TORCH_VERSION |
9 | 11 | from detectron2.utils.testing import random_boxes |
10 | 12 |
|
11 | 13 |
|
@@ -149,5 +151,25 @@ def test_nms_rotated_180_degrees_cpu(self): |
149 | 151 | self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou)) |
150 | 152 |
|
151 | 153 |
|
| 154 | +class TestScriptable(unittest.TestCase): |
| 155 | + def setUp(self): |
| 156 | + class TestingModule(torch.nn.Module): |
| 157 | + def forward(self, boxes, scores, threshold): |
| 158 | + return nms_rotated(boxes, scores, threshold) |
| 159 | + |
| 160 | + self.module = TestingModule() |
| 161 | + |
| 162 | + @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version") |
| 163 | + def test_scriptable_cpu(self): |
| 164 | + m = deepcopy(self.module).cpu() |
| 165 | + _ = torch.jit.script(m) |
| 166 | + |
| 167 | + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") |
| 168 | + @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version") |
| 169 | + def test_scriptable_cuda(self): |
| 170 | + m = deepcopy(self.module).cuda() |
| 171 | + _ = torch.jit.script(m) |
| 172 | + |
| 173 | + |
152 | 174 | if __name__ == "__main__": |
153 | 175 | unittest.main() |
0 commit comments