Skip to content

Commit 9f8d35e

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
add shapes_to_tensor that stacks scalars or ints in a traceable/scriptable way
Summary: The function is reused in 3 different places. Reviewed By: zhanghang1989 Differential Revision: D30527464 fbshipit-source-id: 90b0fdbabd296ba5f4d911a8481c7c454b947b7b
1 parent fb2f17b commit 9f8d35e

File tree

6 files changed

+36
-50
lines changed

6 files changed

+36
-50
lines changed

detectron2/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Linear,
1616
nonzero_tuple,
1717
cross_entropy,
18+
shapes_to_tensor,
1819
)
1920
from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
2021
from .aspp import ASPP

detectron2/layers/wrappers.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,33 @@
88
is implemented
99
"""
1010

11-
from typing import List
11+
from typing import List, Optional
1212
import torch
1313
from torch.nn import functional as F
1414

1515

16+
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
17+
"""
18+
Turn a list of integer scalars or integer Tensor scalars into a vector,
19+
in a way that's both traceable and scriptable.
20+
21+
In tracing, `x` should be a list of scalar Tensor, so the output can trace to the inputs.
22+
In scripting or eager, `x` should be a list of int.
23+
"""
24+
if torch.jit.is_scripting():
25+
return torch.as_tensor(x, device=device)
26+
if torch.jit.is_tracing():
27+
assert all(
28+
[isinstance(t, torch.Tensor) for t in x]
29+
), "Shape should be tensor during tracing!"
30+
# as_tensor should not be used in tracing because it records a constant
31+
ret = torch.stack(x)
32+
if ret.device != device: # avoid recording a hard-coded device if not necessary
33+
ret = ret.to(device=device)
34+
return ret
35+
return torch.as_tensor(x, device=device)
36+
37+
1638
def cat(tensors: List[torch.Tensor], dim: int = 0):
1739
"""
1840
Efficient version of torch.cat that avoids a copy if there is only a single element in a list

detectron2/modeling/poolers.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch import nn
66
from torchvision.ops import RoIPool
77

8-
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple
8+
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
99
from detectron2.structures import Boxes
1010

1111
"""
@@ -58,13 +58,6 @@ def assign_boxes_to_levels(
5858
return level_assignments.to(torch.int64) - min_level
5959

6060

61-
def _fmt_box_list(box_tensor, batch_index: int):
62-
repeated_index = torch.full_like(
63-
box_tensor[:, :1], batch_index, dtype=box_tensor.dtype, device=box_tensor.device
64-
)
65-
return cat((repeated_index, box_tensor), dim=1)
66-
67-
6861
def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
6962
"""
7063
Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
@@ -88,11 +81,13 @@ def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
8881
where batch index is the index in [0, N) identifying which batch image the
8982
rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from.
9083
"""
91-
pooler_fmt_boxes = cat(
92-
[_fmt_box_list(box_list.tensor, i) for i, box_list in enumerate(box_lists)], dim=0
84+
boxes = torch.cat([x.tensor for x in box_lists], dim=0)
85+
# __len__ returns Tensor in tracing.
86+
sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
87+
indices = torch.repeat_interleave(
88+
torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
9389
)
94-
95-
return pooler_fmt_boxes
90+
return cat([indices[:, None], boxes], dim=1)
9691

9792

9893
class ROIPooler(nn.Module):

detectron2/structures/image_list.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,7 @@
55
from torch import device
66
from torch.nn import functional as F
77

8-
9-
def _as_tensor(x: Tuple[int, int]) -> torch.Tensor:
10-
"""
11-
An equivalent of `torch.as_tensor`, but works under tracing if input
12-
is a list of tensor. `torch.as_tensor` will record a constant in tracing,
13-
but this function will use `torch.stack` instead.
14-
"""
15-
if torch.jit.is_scripting():
16-
return torch.as_tensor(x)
17-
if isinstance(x, (list, tuple)) and all([isinstance(t, torch.Tensor) for t in x]):
18-
return torch.stack(x)
19-
return torch.as_tensor(x)
8+
from detectron2.layers.wrappers import shapes_to_tensor
209

2110

2211
class ImageList(object):
@@ -90,7 +79,7 @@ def from_tensors(
9079
assert t.shape[:-2] == tensors[0].shape[:-2], t.shape
9180

9281
image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors]
93-
image_sizes_tensor = [_as_tensor(x) for x in image_sizes]
82+
image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
9483
max_size = torch.stack(image_sizes_tensor).max(0).values
9584

9685
if size_divisibility > 1:

projects/PointRend/point_rend/point_features.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from torch.nn import functional as F
44

5-
from detectron2.layers import cat
5+
from detectron2.layers import cat, shapes_to_tensor
66
from detectron2.structures import BitMasks, Boxes
77

88

@@ -16,15 +16,6 @@
1616
"""
1717

1818

19-
def _as_tensor(x):
20-
"""
21-
An equivalent of `torch.as_tensor`, but works under tracing.
22-
"""
23-
if isinstance(x, (list, tuple)) and all([isinstance(t, torch.Tensor) for t in x]):
24-
return torch.stack(x)
25-
return torch.as_tensor(x)
26-
27-
2819
def point_sample(input, point_coords, **kwargs):
2920
"""
3021
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
@@ -182,7 +173,7 @@ def point_sample_fine_grained_features(features_list, feature_scales, boxes, poi
182173
point_features_per_image = []
183174
for idx_feature, feature_map in enumerate(features_list):
184175
h, w = feature_map.shape[-2:]
185-
scale = _as_tensor([w, h]) / feature_scales[idx_feature]
176+
scale = shapes_to_tensor([w, h]) / feature_scales[idx_feature]
186177
point_coords_scaled = point_coords_wrt_image_per_image / scale.to(feature_map.device)
187178
point_features_per_image.append(
188179
point_sample(

tests/modeling/test_roi_pooler.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import unittest
44
import torch
55

6-
from detectron2.modeling.poolers import ROIPooler, _fmt_box_list
6+
from detectron2.modeling.poolers import ROIPooler
77
from detectron2.structures import Boxes, RotatedBoxes
88
from detectron2.utils.testing import random_boxes
99

@@ -114,18 +114,6 @@ def test_no_images(self):
114114
output = pooler.forward(features, [])
115115
self.assertEqual(output.shape, (0, C, 14, 14))
116116

117-
def test_fmt_box_list_tracing(self):
118-
class Model(torch.nn.Module):
119-
def forward(self, box_tensor):
120-
return _fmt_box_list(box_tensor, 0)
121-
122-
with torch.no_grad():
123-
func = torch.jit.trace(Model(), torch.ones(10, 4))
124-
125-
self.assertEqual(func(torch.ones(10, 4)).shape, (10, 5))
126-
self.assertEqual(func(torch.ones(5, 4)).shape, (5, 5))
127-
self.assertEqual(func(torch.ones(20, 4)).shape, (20, 5))
128-
129117
def test_roi_pooler_tracing(self):
130118
class Model(torch.nn.Module):
131119
def __init__(self, roi):

0 commit comments

Comments
 (0)