Skip to content

Commit f14e631

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
add visualizer.draw_soft_mask
Reviewed By: zhanghang1989 Differential Revision: D32688852 fbshipit-source-id: a47f3b507fa520b1d66d69a90b02126489f3debb
1 parent 1ad5759 commit f14e631

File tree

4 files changed

+71
-17
lines changed

4 files changed

+71
-17
lines changed

INSTALL.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Click each issue for its solutions:
6767

6868
<details>
6969
<summary>
70-
Undefined symbols that contains TH,aten,torch,caffe2.
70+
Undefined symbols that looks like "TH..","at::Tensor...","torch..."
7171
</summary>
7272
<br/>
7373

@@ -96,7 +96,7 @@ compiled with the version of PyTorch you're running. See the previous common iss
9696

9797
<details>
9898
<summary>
99-
Undefined C++ symbols (e.g. GLIBCXX) or C++ symbols not found.
99+
Undefined C++ symbols (e.g. "GLIBCXX..") or C++ symbols not found.
100100
</summary>
101101
<br/>
102102
Usually it's because the library is compiled with a newer C++ compiler but run with an old C++ runtime.

detectron2/utils/visualizer.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
10321032
return self.output
10331033

10341034
def draw_binary_mask(
1035-
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0
1035+
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10
10361036
):
10371037
"""
10381038
Args:
@@ -1043,9 +1043,9 @@ def draw_binary_mask(
10431043
formats that are accepted. If None, will pick a random color.
10441044
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
10451045
full list of formats that are accepted.
1046-
text (str): if None, will be drawn in the object's center of mass.
1046+
text (str): if None, will be drawn on the object
10471047
alpha (float): blending efficient. Smaller values lead to more transparent masks.
1048-
area_threshold (float): a connected component small than this will not be shown.
1048+
area_threshold (float): a connected component smaller than this area will not be shown.
10491049
10501050
Returns:
10511051
output (VisImage): image object with mask drawn.
@@ -1078,18 +1078,36 @@ def draw_binary_mask(
10781078
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
10791079

10801080
if text is not None and has_valid_segment:
1081-
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
10821081
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1083-
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
1084-
largest_component_id = np.argmax(stats[1:, -1]) + 1
1085-
1086-
# draw text on the largest component, as well as other very large components.
1087-
for cid in range(1, _num_cc):
1088-
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1089-
# median is more stable than centroid
1090-
# center = centroids[largest_component_id]
1091-
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1092-
self.draw_text(text, center, color=lighter_color)
1082+
self._draw_text_in_mask(binary_mask, text, lighter_color)
1083+
return self.output
1084+
1085+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1086+
"""
1087+
Args:
1088+
soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1089+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
1090+
formats that are accepted. If None, will pick a random color.
1091+
text (str): if None, will be drawn on the object
1092+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
1093+
1094+
Returns:
1095+
output (VisImage): image object with mask drawn.
1096+
"""
1097+
if color is None:
1098+
color = random_color(rgb=True, maximum=1)
1099+
color = mplc.to_rgb(color)
1100+
1101+
shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1102+
rgba = np.zeros(shape2d + (4,), dtype="float32")
1103+
rgba[:, :, :3] = color
1104+
rgba[:, :, 3] = soft_mask * alpha
1105+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1106+
1107+
if text is not None:
1108+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1109+
binary_mask = (soft_mask > 0.5).astype("uint8")
1110+
self._draw_text_in_mask(binary_mask, text, lighter_color)
10931111
return self.output
10941112

10951113
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
@@ -1215,6 +1233,24 @@ def _convert_masks(self, masks_or_polygons):
12151233
ret.append(GenericMask(x, self.output.height, self.output.width))
12161234
return ret
12171235

1236+
def _draw_text_in_mask(self, binary_mask, text, color):
1237+
"""
1238+
Find proper places to draw text given a binary mask.
1239+
"""
1240+
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
1241+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
1242+
if stats[1:, -1].size == 0:
1243+
return
1244+
largest_component_id = np.argmax(stats[1:, -1]) + 1
1245+
1246+
# draw text on the largest component, as well as other very large components.
1247+
for cid in range(1, _num_cc):
1248+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1249+
# median is more stable than centroid
1250+
# center = centroids[largest_component_id]
1251+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1252+
self.draw_text(text, center, color=color)
1253+
12181254
def _convert_keypoints(self, keypoints):
12191255
if isinstance(keypoints, Keypoints):
12201256
keypoints = keypoints.tensor

docs/tutorials/models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ The dict may contain the following keys:
7171
* "image": `Tensor` in (C, H, W) format. The meaning of channels are defined by `cfg.INPUT.FORMAT`.
7272
Image normalization, if any, will be performed inside the model using
7373
`cfg.MODEL.PIXEL_{MEAN,STD}`.
74-
* "height", "width": the **desired** output height and width, which is not necessarily the same
74+
* "height", "width": the **desired** output height and width **in inference**, which is not necessarily the same
7575
as the height or width of the `image` field.
7676
For example, the `image` field contains the resized image, if resize is used as a preprocessing step.
7777
But you may want the outputs to be in **original** resolution.

tests/test_visualizer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,24 @@ def test_draw_binary_mask(self):
213213
# red color is drawn on the image
214214
self.assertTrue(o[:, :, 0].sum() > 0)
215215

216+
def test_draw_soft_mask(self):
217+
img = np.random.rand(100, 100, 3) * 255
218+
img[:, :, 0] = 0 # remove red color
219+
mask = np.zeros((100, 100), dtype=np.float32)
220+
mask[30:50, 40:50] = 1.0
221+
cv2.GaussianBlur(mask, (21, 21), 10)
222+
223+
v = Visualizer(img)
224+
o = v.draw_soft_mask(mask, color="red", text="test")
225+
o = o.get_image().astype("float32")
226+
# red color is drawn on the image
227+
self.assertTrue(o[:, :, 0].sum() > 0)
228+
229+
# test draw empty mask
230+
v = Visualizer(img)
231+
o = v.draw_soft_mask(np.zeros((100, 100), dtype=np.float32), color="red", text="test")
232+
o = o.get_image().astype("float32")
233+
216234
def test_border_mask_with_holes(self):
217235
H, W = 200, 200
218236
img = np.zeros((H, W, 3))

0 commit comments

Comments
 (0)