Skip to content

Commit 6545979

Browse files
ppwwyyxx, facebook-github-bot
authored and committed
fix ColorMode.IMAGE_BW of visualizer
Summary: fix #3486 Reviewed By: zhanghang1989 Differential Revision: D30974958 fbshipit-source-id: 68c7d041f3b5b64eb0c32b4fcc99ad6d13e3542f
1 parent ce5b1c5 commit 6545979

File tree

5 files changed

+36
-15
lines changed

5 files changed

+36
-15
lines changed

detectron2/config/defaults.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
# INPUT
5151
# -----------------------------------------------------------------------------
5252
_C.INPUT = CN()
53+
# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge.
54+
# Please refer to ResizeShortestEdge for detailed definition.
5355
# Size of the smallest side of the image during training
5456
_C.INPUT.MIN_SIZE_TRAIN = (800,)
5557
# Sample size of smallest side by choice or random selection from range given by
@@ -258,7 +260,7 @@
258260
# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
259261
_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
260262
_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
261-
# RoI minibatch size *per image* (number of regions of interest [ROIs])
263+
# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training
262264
# Total number of RoIs per training minibatch =
263265
# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
264266
# E.g., a common configuration is: 512 * 16 = 8192

detectron2/data/transforms/augmentation_impl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def get_transform(self, image):
128128

129129
class ResizeShortestEdge(Augmentation):
130130
"""
131-
Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
131+
Resize the image while keeping the aspect ratio unchanged.
132+
It attempts to scale the shorter edge to the given `short_edge_length`,
133+
as long as the longer edge does not exceed `max_size`.
132134
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
133135
"""
134136

detectron2/utils/video_visualizer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ def draw_instance_predictions(self, frame, predictions):
9191

9292
if self._instance_mode == ColorMode.IMAGE_BW:
9393
# any() returns uint8 tensor
94-
frame_visualizer.output.img = frame_visualizer._create_grayscale_image(
95-
(masks.any(dim=0) > 0).numpy() if masks is not None else None
94+
frame_visualizer.output.reset_image(
95+
frame_visualizer._create_grayscale_image(
96+
(masks.any(dim=0) > 0).numpy() if masks is not None else None
97+
)
9698
)
9799
alpha = 0.3
98100
else:
@@ -128,8 +130,8 @@ def draw_panoptic_seg_predictions(
128130
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
129131

130132
if self._instance_mode == ColorMode.IMAGE_BW:
131-
frame_visualizer.output.img = frame_visualizer._create_grayscale_image(
132-
pred.non_empty_mask()
133+
frame_visualizer.output.reset_image(
134+
frame_visualizer._create_grayscale_image(pred.non_empty_mask())
133135
)
134136

135137
# draw mask for all semantic segments first i.e. "stuff"

detectron2/utils/visualizer.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ class VisImage:
255255
def __init__(self, img, scale=1.0):
256256
"""
257257
Args:
258-
img (ndarray): an RGB image of shape (H, W, 3).
258+
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
259259
scale (float): scale the input image
260260
"""
261261
self.img = img
@@ -284,11 +284,17 @@ def _setup_figure(self, img):
284284
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
285285
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
286286
ax.axis("off")
287-
# Need to imshow this first so that other patches can be drawn on top
288-
ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
289-
290287
self.fig = fig
291288
self.ax = ax
289+
self.reset_image(img)
290+
291+
def reset_image(self, img):
292+
"""
293+
Args:
294+
img: same as in __init__
295+
"""
296+
img = img.astype("uint8")
297+
self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
292298

293299
def save(self, filepath):
294300
"""
@@ -404,10 +410,12 @@ def draw_instance_predictions(self, predictions):
404410
alpha = 0.5
405411

406412
if self._instance_mode == ColorMode.IMAGE_BW:
407-
self.output.img = self._create_grayscale_image(
408-
(predictions.pred_masks.any(dim=0) > 0).numpy()
409-
if predictions.has("pred_masks")
410-
else None
413+
self.output.reset_image(
414+
self._create_grayscale_image(
415+
(predictions.pred_masks.any(dim=0) > 0).numpy()
416+
if predictions.has("pred_masks")
417+
else None
418+
)
411419
)
412420
alpha = 0.3
413421

@@ -476,7 +484,7 @@ def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, al
476484
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
477485

478486
if self._instance_mode == ColorMode.IMAGE_BW:
479-
self.output.img = self._create_grayscale_image(pred.non_empty_mask())
487+
self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
480488

481489
# draw mask for all semantic segments first i.e. "stuff"
482490
for mask, sinfo in pred.semantic_masks():

tests/test_visualizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ def test_BWmode_nomask(self):
138138
v = Visualizer(img, self.metadata, instance_mode=ColorMode.IMAGE_BW)
139139
v.draw_instance_predictions(inst)
140140

141+
# check that output is grayscale
142+
inst = inst[:0]
143+
v = Visualizer(img, self.metadata, instance_mode=ColorMode.IMAGE_BW)
144+
output = v.draw_instance_predictions(inst).get_image()
145+
self.assertTrue(np.allclose(output[:, :, 0], output[:, :, 1]))
146+
self.assertTrue(np.allclose(output[:, :, 0], output[:, :, 2]))
147+
141148
def test_draw_empty_mask_predictions(self):
142149
img, boxes, _, _, masks = self._random_data()
143150
num_inst = len(boxes)

0 commit comments

Comments
 (0)