Skip to content

Commit 9fb31b3

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
fix sem_seg inference without "height" and "width"
Summary: fix #3491 Reviewed By: sstsai-adl Differential Revision: D31061615 fbshipit-source-id: 1080687340721a31d3689ea5b2450eef891e6e33
1 parent 0316cb7 commit 9fb31b3

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

detectron2/modeling/meta_arch/semantic_seg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def forward(self, batched_inputs):
117117

118118
processed_results = []
119119
for result, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
120-
height = input_per_image.get("height")
121-
width = input_per_image.get("width")
120+
height = input_per_image.get("height", image_size[0])
121+
width = input_per_image.get("width", image_size[1])
122122
r = sem_seg_postprocess(result, image_size, height, width)
123123
processed_results.append({"sem_seg": r})
124124
return processed_results

tests/modeling/test_model_e2e.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_regular_bitmask_instances(h, w):
7373
return inst
7474

7575

76-
class ModelE2ETest:
76+
class InstanceModelE2ETest:
7777
def setUp(self):
7878
torch.manual_seed(43)
7979
self.model = get_model_no_weights(self.CONFIG_PATH)
@@ -115,7 +115,7 @@ def test_eval_tocpu(self):
115115
model(inputs)
116116

117117

118-
class MaskRCNNE2ETest(ModelE2ETest, unittest.TestCase):
118+
class MaskRCNNE2ETest(InstanceModelE2ETest, unittest.TestCase):
119119
CONFIG_PATH = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
120120

121121
def test_half_empty_data(self):
@@ -171,7 +171,7 @@ def test_autocast(self):
171171
self.assertEqual(out.scores.dtype, torch.float32) # scores comes from softmax
172172

173173

174-
class RetinaNetE2ETest(ModelE2ETest, unittest.TestCase):
174+
class RetinaNetE2ETest(InstanceModelE2ETest, unittest.TestCase):
175175
CONFIG_PATH = "COCO-Detection/retinanet_R_50_FPN_1x.yaml"
176176

177177
def test_inf_nan_data(self):
@@ -209,3 +209,19 @@ def test_autocast(self):
209209
out = self.model(inputs)[0]["instances"]
210210
self.assertEqual(out.pred_boxes.tensor.dtype, torch.float32)
211211
self.assertEqual(out.scores.dtype, torch.float16)
212+
213+
214+
class SemSegE2ETest(unittest.TestCase):
215+
CONFIG_PATH = "Misc/semantic_R_50_FPN_1x.yaml"
216+
217+
def setUp(self):
218+
torch.manual_seed(43)
219+
self.model = get_model_no_weights(self.CONFIG_PATH)
220+
221+
def _test_eval(self, input_sizes):
222+
inputs = [create_model_input(torch.rand(3, s[0], s[1])) for s in input_sizes]
223+
self.model.eval()
224+
self.model(inputs)
225+
226+
def test_forward(self):
227+
self._test_eval([(200, 250), (200, 249)])

0 commit comments

Comments
 (0)