Skip to content

Commit e9cbe90

Browse files
ppwwyyxx authored and facebook-github-bot committed
add a scriptable ResizeShortestEdge.get_output_shape
Summary: this function alone is useful in deployment.

Reviewed By: zhanghang1989

Differential Revision: D30801733

fbshipit-source-id: 792f8ca016f2c6782fc25c9bbaa302588597d087
1 parent 23486b6 commit e9cbe90

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

detectron2/data/transforms/augmentation_impl.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import sys
88
from typing import Tuple
9+
import torch
910
from fvcore.transforms.transform import (
1011
BlendTransform,
1112
CropTransform,
@@ -131,6 +132,7 @@ class ResizeShortestEdge(Augmentation):
131132
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
132133
"""
133134

135+
@torch.jit.unused
134136
def __init__(
135137
self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
136138
):
@@ -155,6 +157,7 @@ def __init__(
155157
)
156158
self._init(locals())
157159

160+
@torch.jit.unused
158161
def get_transform(self, image):
159162
h, w = image.shape[:2]
160163
if self.is_range:
@@ -164,18 +167,30 @@ def get_transform(self, image):
164167
if size == 0:
165168
return NoOpTransform()
166169

167-
scale = size * 1.0 / min(h, w)
170+
newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
171+
return ResizeTransform(h, w, newh, neww, self.interp)
172+
173+
@staticmethod
174+
def get_output_shape(
175+
oldh: int, oldw: int, short_edge_length: int, max_size: int
176+
) -> Tuple[int, int]:
177+
"""
178+
Compute the output size given input size and target short edge length.
179+
"""
180+
h, w = oldh, oldw
181+
size = short_edge_length * 1.0
182+
scale = size / min(h, w)
168183
if h < w:
169184
newh, neww = size, scale * w
170185
else:
171186
newh, neww = scale * h, size
172-
if max(newh, neww) > self.max_size:
173-
scale = self.max_size * 1.0 / max(newh, neww)
187+
if max(newh, neww) > max_size:
188+
scale = max_size * 1.0 / max(newh, neww)
174189
newh = newh * scale
175190
neww = neww * scale
176191
neww = int(neww + 0.5)
177192
newh = int(newh + 0.5)
178-
return ResizeTransform(h, w, newh, neww, self.interp)
193+
return (newh, neww)
179194

180195

181196
class ResizeScale(Augmentation):
@@ -393,7 +408,7 @@ def get_crop_size(self, image_size):
393408
cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
394409
return ch, cw
395410
else:
396-
NotImplementedError("Unknown crop type {}".format(self.crop_type))
411+
raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
397412

398413

399414
class RandomCrop_CategoryAreaConstraint(Augmentation):

detectron2/export/torchscript.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def dump_torchscript_IR(model, dir):
7171
model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
7272
dir (str): output directory to dump files.
7373
"""
74+
dir = os.path.expanduser(dir)
7475
PathManager.mkdirs(dir)
7576

7677
def _get_script_mod(mod):

detectron2/modeling/mmdet_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
# "Neck" weights, if any, are part of neck itself. This is the interface
8080
# of mmdet so we follow it. Reference:
8181
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py
82-
logger.info(f"Initializing mmdet backbone weights...")
82+
logger.info("Initializing mmdet backbone weights...")
8383
self.backbone.init_weights()
8484
# train() in mmdet modules is non-trivial, and has to be explicitly
8585
# called. Reference:

tests/data/test_transforms.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import numpy as np
66
import unittest
77
from unittest import mock
8+
import torch
89
from PIL import Image, ImageOps
10+
from torch.nn import functional as F
911

1012
from detectron2.config import get_cfg
1113
from detectron2.data import detection_utils
@@ -225,7 +227,22 @@ def test_resize_transform(self):
225227
in_img = np.random.randint(0, 255, size=in_shape, dtype=np.uint8)
226228
tfm = T.ResizeTransform(in_shape[0], in_shape[1], out_shape[0], out_shape[1])
227229
out_img = tfm.apply_image(in_img)
228-
self.assertTrue(out_img.shape == out_shape)
230+
self.assertEqual(out_img.shape, out_shape)
231+
232+
def test_resize_shorted_edge_scriptable(self):
233+
def f(image):
234+
newh, neww = T.ResizeShortestEdge.get_output_shape(
235+
image.shape[-2], image.shape[-1], 80, 133
236+
)
237+
return F.interpolate(image.unsqueeze(0), size=(newh, neww))
238+
239+
input = torch.randn(3, 10, 10)
240+
script_f = torch.jit.script(f)
241+
self.assertTrue(torch.allclose(f(input), script_f(input)))
242+
243+
# generalize to new shapes
244+
input = torch.randn(3, 8, 100)
245+
self.assertTrue(torch.allclose(f(input), script_f(input)))
229246

230247
def test_extent_transform(self):
231248
input_shapes = [(100, 100), (100, 100, 1), (100, 100, 3)]

0 commit comments

Comments
 (0)