
Commit 211efad

lyttonhao authored and facebook-github-bot committed
support square padding in backbone
Summary:
X-link: facebookresearch/d2go#258

Support square padding case in backbone.

Reviewed By: wat3rBro

Differential Revision: D35552076

fbshipit-source-id: e4f7f4da62b6ee9b71686071ff6cf2747ecc90e0
1 parent 32c32e3 commit 211efad

File tree: 8 files changed (+78, −12 lines)


detectron2/modeling/backbone/backbone.py

Lines changed: 21 additions & 0 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from abc import ABCMeta, abstractmethod
+from typing import Dict
 import torch.nn as nn

 from detectron2.layers import ShapeSpec
@@ -39,6 +40,26 @@ def size_divisibility(self) -> int:
         """
         return 0

+    @property
+    def padding_constraints(self) -> Dict[str, int]:
+        """
+        This property is a generalization of size_divisibility. Some backbones and training
+        recipes require specific padding constraints, such as enforcing divisibility by a
+        specific integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale
+        jitter in :paper:`vitdet`). `padding_constraints` contains these optional items:
+        {
+            "size_divisibility": int,
+            "square": int,
+            # Future options are possible
+        }
+        `size_divisibility` will be read from here if present, and `square` indicates whether
+        inputs need to be padded to a square. Return an empty dict if there are no specific
+        padding constraints.
+
+        TODO: use type of Dict[str, int] to avoid torchscript issues. The type of
+        padding_constraints could be generalized as TypedDict (Python 3.8+) to support
+        more types in the future.
+        """
+        return {}
+
     def output_shape(self):
         """
         Returns:
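To see how a subclass opts in, here is a minimal sketch of a custom backbone overriding the new property. The class name and layer choices are made up for illustration; only `Backbone`, `ShapeSpec`, and the `padding_constraints` contract come from the codebase, and the example assumes a detectron2 build that includes this change:

```python
from typing import Dict

import torch.nn as nn

from detectron2.layers import ShapeSpec
from detectron2.modeling import Backbone


class SquarePadBackbone(Backbone):
    """Hypothetical backbone that asks callers to pad inputs to a 32-divisible square."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return {"feat": self.conv(x)}

    def output_shape(self):
        return {"feat": ShapeSpec(channels=64, stride=2)}

    @property
    def padding_constraints(self) -> Dict[str, int]:
        # Consumed by ImageList.from_tensors via the meta-arch changes below.
        return {"square": 1, "size_divisibility": 32}
```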

detectron2/modeling/backbone/fpn.py

Lines changed: 13 additions & 1 deletion

@@ -23,7 +23,14 @@ class FPN(Backbone):
     _fuse_type: torch.jit.Final[str]

     def __init__(
-        self, bottom_up, in_features, out_channels, norm="", top_block=None, fuse_type="sum"
+        self,
+        bottom_up,
+        in_features,
+        out_channels,
+        norm="",
+        top_block=None,
+        fuse_type="sum",
+        square_pad=False,
     ):
         """
         Args:
@@ -103,13 +110,18 @@ def __init__(
         self._out_features = list(self._out_feature_strides.keys())
         self._out_feature_channels = {k: out_channels for k in self._out_features}
         self._size_divisibility = strides[-1]
+        self._square_pad = square_pad
         assert fuse_type in {"avg", "sum"}
         self._fuse_type = fuse_type

     @property
     def size_divisibility(self):
         return self._size_divisibility

+    @property
+    def padding_constraints(self):
+        return {"square": int(self._square_pad)}
+
     def forward(self, x):
         """
         Args:
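A quick construction-time check of the new flag. The tiny bottom-up below is a stand-in invented for this sketch, not part of the commit; it exists only so the FPN can be instantiated without a config:

```python
import torch.nn as nn

from detectron2.layers import ShapeSpec
from detectron2.modeling import Backbone
from detectron2.modeling.backbone.fpn import FPN


class TinyBottomUp(Backbone):
    """Hypothetical two-stage bottom-up, just enough to build an FPN on top."""

    def __init__(self):
        super().__init__()
        self.s1 = nn.Conv2d(3, 8, kernel_size=3, stride=4, padding=1)
        self.s2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        c1 = self.s1(x)
        return {"c1": c1, "c2": self.s2(c1)}

    def output_shape(self):
        return {"c1": ShapeSpec(channels=8, stride=4), "c2": ShapeSpec(channels=16, stride=8)}


fpn = FPN(TinyBottomUp(), in_features=["c1", "c2"], out_channels=16, square_pad=True)
print(fpn.padding_constraints)  # {'square': 1}
print(fpn.size_divisibility)    # 8, the stride of the coarsest input level
```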

detectron2/modeling/meta_arch/dense_detector.py

Lines changed: 5 additions & 2 deletions

@@ -62,7 +62,6 @@ def __init__(
             self.head_in_features = sorted(shapes.keys(), key=lambda x: shapes[x].stride)
         else:
             self.head_in_features = head_in_features
-
         self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

@@ -127,7 +126,11 @@ def preprocess_image(self, batched_inputs: List[Dict[str, Tensor]]):
         """
         images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
-        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
+        images = ImageList.from_tensors(
+            images,
+            self.backbone.size_divisibility,
+            padding_constraints=self.backbone.padding_constraints,
+        )
         return images

     def _transpose_dense_predictions(

detectron2/modeling/meta_arch/panoptic_fpn.py

Lines changed: 4 additions & 1 deletion

@@ -119,7 +119,10 @@ def forward(self, batched_inputs):
         assert "sem_seg" in batched_inputs[0]
         gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
         gt_sem_seg = ImageList.from_tensors(
-            gt_sem_seg, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
+            gt_sem_seg,
+            self.backbone.size_divisibility,
+            self.sem_seg_head.ignore_value,
+            self.backbone.padding_constraints,
         ).tensor
         sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg)

detectron2/modeling/meta_arch/rcnn.py

Lines changed: 10 additions & 2 deletions

@@ -227,7 +227,11 @@ def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
         """
         images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
-        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
+        images = ImageList.from_tensors(
+            images,
+            self.backbone.size_divisibility,
+            padding_constraints=self.backbone.padding_constraints,
+        )
         return images

     @staticmethod
@@ -305,7 +309,11 @@ def forward(self, batched_inputs):
         """
         images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
-        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
+        images = ImageList.from_tensors(
+            images,
+            self.backbone.size_divisibility,
+            padding_constraints=self.backbone.padding_constraints,
+        )
         features = self.backbone(images.tensor)

         if "instances" in batched_inputs[0]:

detectron2/modeling/meta_arch/semantic_seg.py

Lines changed: 9 additions & 2 deletions

@@ -99,14 +99,21 @@ def forward(self, batched_inputs):
         """
         images = [x["image"].to(self.device) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
-        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
+        images = ImageList.from_tensors(
+            images,
+            self.backbone.size_divisibility,
+            padding_constraints=self.backbone.padding_constraints,
+        )

         features = self.backbone(images.tensor)

         if "sem_seg" in batched_inputs[0]:
             targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
             targets = ImageList.from_tensors(
-                targets, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
+                targets,
+                self.backbone.size_divisibility,
+                self.sem_seg_head.ignore_value,
+                self.backbone.padding_constraints,
             ).tensor
         else:
             targets = None
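Because the ground-truth maps are padded under the same constraints as the images, per-pixel targets stay aligned with square-padded inputs. A small self-contained check against a build that includes this change (values are arbitrary; 255 stands in for an ignore label):

```python
import torch

from detectron2.structures import ImageList

constraints = {"square": 1}  # e.g. what an FPN built with square_pad=True reports
images = [torch.rand(3, 20, 30)]
targets = [torch.randint(0, 10, (20, 30))]

image_list = ImageList.from_tensors(images, 32, padding_constraints=constraints)
target_list = ImageList.from_tensors(targets, 32, 255, constraints)

# Both are padded 30 -> square -> rounded up to the divisibility of 32.
assert image_list.tensor.shape[-2:] == target_list.tensor.shape[-2:] == torch.Size([32, 32])
```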

detectron2/modeling/postprocessing.py

Lines changed: 0 additions & 1 deletion

@@ -23,7 +23,6 @@ def detector_postprocess(
             `results.image_size` contains the input image resolution the detector sees.
             This object might be modified in-place.
         output_height, output_width: the desired output resolution.
-
     Returns:
         Instances: the resized output from the model, based on the output resolution
     """

detectron2/structures/image_list.py

Lines changed: 16 additions & 3 deletions

@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from __future__ import division
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import device
 from torch.nn import functional as F
@@ -57,7 +57,10 @@ def device(self) -> device:

     @staticmethod
     def from_tensors(
-        tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
+        tensors: List[torch.Tensor],
+        size_divisibility: int = 0,
+        pad_value: float = 0.0,
+        padding_constraints: Optional[Dict[str, int]] = None,
     ) -> "ImageList":
         """
         Args:
@@ -67,7 +70,11 @@ def from_tensors(
             size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
                 the common height and width is divisible by `size_divisibility`.
                 This depends on the model and many models need a divisibility of 32.
-            pad_value (float): value to pad
+            pad_value (float): value to pad.
+            padding_constraints (Optional[Dict]): If given, it follows the format
+                {"size_divisibility": int, "square": int}, where `size_divisibility`
+                overrides the argument above if present, and `square` indicates whether
+                inputs need to be padded to a square.

         Returns:
             an `ImageList`.
@@ -82,6 +89,12 @@ def from_tensors(
         image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
         max_size = torch.stack(image_sizes_tensor).max(0).values

+        if padding_constraints is not None:
+            if padding_constraints.get("square", 0) > 0:
+                # pad to square.
+                max_size[0] = max_size[1] = max_size.max()
+            if "size_divisibility" in padding_constraints:
+                size_divisibility = padding_constraints["size_divisibility"]
         if size_divisibility > 1:
             stride = size_divisibility
             # the last two dims are H,W, both subject to divisibility requirement
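The effect of the new branch, as a sketch with shapes worked out by hand (again assuming a detectron2 build that includes this change):

```python
import torch

from detectron2.structures import ImageList

imgs = [torch.rand(3, 20, 30), torch.rand(3, 25, 17)]

# Default behavior: pad to the per-dimension max, (25, 30).
plain = ImageList.from_tensors(imgs)
print(plain.tensor.shape)  # torch.Size([2, 3, 25, 30])

# New behavior: square padding takes the overall max (30), then the
# "size_divisibility" key overrides the argument and rounds up to 32.
squared = ImageList.from_tensors(
    imgs, padding_constraints={"square": 1, "size_divisibility": 32}
)
print(squared.tensor.shape)  # torch.Size([2, 3, 32, 32])
```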
