@@ -1032,7 +1032,7 @@ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
10321032 return self .output
10331033
10341034 def draw_binary_mask (
1035- self , binary_mask , color = None , * , edge_color = None , text = None , alpha = 0.5 , area_threshold = 0
1035+ self , binary_mask , color = None , * , edge_color = None , text = None , alpha = 0.5 , area_threshold = 10
10361036 ):
10371037 """
10381038 Args:
@@ -1043,9 +1043,9 @@ def draw_binary_mask(
10431043 formats that are accepted. If None, will pick a random color.
10441044 edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
10451045 full list of formats that are accepted.
1046- text (str): if None, will be drawn in the object's center of mass.
1046+ text (str): if not None, will be drawn on the object
10471047 alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1048- area_threshold (float): a connected component small than this will not be shown.
1048+ area_threshold (float): a connected component smaller than this area will not be shown.
10491049
10501050 Returns:
10511051 output (VisImage): image object with mask drawn.
@@ -1078,18 +1078,36 @@ def draw_binary_mask(
10781078 self .output .ax .imshow (rgba , extent = (0 , self .output .width , self .output .height , 0 ))
10791079
10801080 if text is not None and has_valid_segment :
1081- # TODO sometimes drawn on wrong objects. the heuristics here can improve.
10821081 lighter_color = self ._change_color_brightness (color , brightness_factor = 0.7 )
1083- _num_cc , cc_labels , stats , centroids = cv2 .connectedComponentsWithStats (binary_mask , 8 )
1084- largest_component_id = np .argmax (stats [1 :, - 1 ]) + 1
1085-
1086- # draw text on the largest component, as well as other very large components.
1087- for cid in range (1 , _num_cc ):
1088- if cid == largest_component_id or stats [cid , - 1 ] > _LARGE_MASK_AREA_THRESH :
1089- # median is more stable than centroid
1090- # center = centroids[largest_component_id]
1091- center = np .median ((cc_labels == cid ).nonzero (), axis = 1 )[::- 1 ]
1092- self .draw_text (text , center , color = lighter_color )
1082+ self ._draw_text_in_mask (binary_mask , text , lighter_color )
1083+ return self .output
1084+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
    """
    Overlay a soft (per-pixel weighted) mask on the output image.

    Args:
        soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
        color: color of the mask. Refer to `matplotlib.colors` for a full list of
            formats that are accepted. If None, will pick a random color.
        text (str): if not None, will be drawn on the object. Placement uses
            the soft mask thresholded at 0.5.
        alpha (float): blending coefficient. Smaller values lead to more
            transparent masks.

    Returns:
        output (VisImage): image object with mask drawn.
    """
    mask_rgb = mplc.to_rgb(random_color(rgb=True, maximum=1) if color is None else color)

    # Build an RGBA overlay: constant color, with per-pixel opacity given by
    # the soft mask scaled by alpha.
    height, width = soft_mask.shape[0], soft_mask.shape[1]
    overlay = np.zeros((height, width, 4), dtype="float32")
    overlay[..., :3] = mask_rgb
    overlay[..., 3] = soft_mask * alpha
    self.output.ax.imshow(
        overlay, extent=(0, self.output.width, self.output.height, 0)
    )

    if text is None:
        return self.output

    # Text placement works on a binary mask, so binarize the soft mask first.
    text_color = self._change_color_brightness(mask_rgb, brightness_factor=0.7)
    thresholded = (soft_mask > 0.5).astype("uint8")
    self._draw_text_in_mask(thresholded, text, text_color)
    return self.output
10941112
10951113 def draw_polygon (self , segment , color , edge_color = None , alpha = 0.5 ):
@@ -1215,6 +1233,24 @@ def _convert_masks(self, masks_or_polygons):
12151233 ret .append (GenericMask (x , self .output .height , self .output .width ))
12161234 return ret
12171235
def _draw_text_in_mask(self, binary_mask, text, color):
    """
    Find proper places to draw text given a binary mask.

    The text is drawn on the largest 8-connected component of the mask, and
    on every other component whose last stats column (read here as its area)
    exceeds `_LARGE_MASK_AREA_THRESH`. Nothing is drawn when the mask has no
    component besides label 0.

    Args:
        binary_mask (ndarray): mask passed to
            `cv2.connectedComponentsWithStats`.
        text (str): text to draw.
        color: color handed to `draw_text`.
    """
    # TODO sometimes drawn on wrong objects. the heuristics here can improve.
    num_labels, label_map, stats, _centroids = cv2.connectedComponentsWithStats(
        binary_mask, 8
    )

    # Label 0 is skipped throughout; only labels >= 1 are candidates.
    component_areas = stats[1:, -1]
    if component_areas.size == 0:
        return
    biggest_label = np.argmax(component_areas) + 1

    for label in range(1, num_labels):
        is_biggest = label == biggest_label
        if not (is_biggest or stats[label, -1] > _LARGE_MASK_AREA_THRESH):
            continue
        # median is more stable than centroid
        rows_and_cols = (label_map == label).nonzero()
        # nonzero() yields (rows, cols); reverse the medians to get (x, y).
        anchor = np.median(rows_and_cols, axis=1)[::-1]
        self.draw_text(text, anchor, color=color)
1253+
12181254 def _convert_keypoints (self , keypoints ):
12191255 if isinstance (keypoints , Keypoints ):
12201256 keypoints = keypoints .tensor
0 commit comments