22from scipy .ndimage import binary_dilation
33
44
5- class PointPromptGenerator :
6- def __init__ (self , n_positive_points , n_negative_points , dilation_strength ):
5+ class PointAndBoxPromptGenerator :
6+ def __init__ (self , n_positive_points , n_negative_points , dilation_strength ,
7+ get_point_prompts = True , get_box_prompts = False ):
78 self .n_positive_points = n_positive_points
89 self .n_negative_points = n_negative_points
910 self .dilation_strength = dilation_strength
11+ self .get_box_prompts = get_box_prompts
12+ self .get_point_prompts = get_point_prompts
13+
14+ if self .get_point_prompts is False and self .get_box_prompts is False :
15+ raise ValueError ("You need to request for box/point prompts or both" )
1016
1117 def __call__ (self , gt , gt_id , center_coordinates , bbox_coordinates ):
1218 """
@@ -20,9 +26,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2026 label_list = []
2127
2228 # getting the center coordinate as the first positive point
23- coord_list .append (tuple (map (int , center_coordinates )))
29+ coord_list .append (tuple (map (int , center_coordinates ))) # to get int coords instead of float
2430 label_list .append (1 )
2531
32+ if self .get_box_prompts :
33+ bbox_list = [bbox_coordinates ]
34+
2635 object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
2736
2837 # getting the additional positive points by randomly sampling points from this mask
@@ -52,8 +61,10 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
5261 dilated_object = binary_dilation (object_mask , iterations = self .dilation_strength )
5362 background_mask = np .zeros (gt .shape )
5463 background_mask [bbox_coordinates [0 ]:bbox_coordinates [2 ], bbox_coordinates [1 ]:bbox_coordinates [3 ]] = 1
55- background_mask = abs (background_mask - dilated_object )
5664 background_mask = binary_dilation (background_mask , iterations = self .dilation_strength )
65+ background_mask = abs (
66+ background_mask .astype (np .float32 ) - dilated_object .astype (np .float32 )
67+ ) # casting booleans to do subtraction
5768
5869 n_negative_remaining = self .n_negative_points
5970 if n_negative_remaining > 0 :
@@ -72,4 +83,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7283 coord_list .append (negative_coordinates )
7384 label_list .append (0 )
7485
75- return coord_list , label_list , None , object_mask
86+ # returns object-level masks per instance for cross-verification (TODO: fix it later)
87+ if self .get_point_prompts is True and self .get_box_prompts is True : # we want points and box
88+ return coord_list , label_list , bbox_list , object_mask
89+
90+ elif self .get_point_prompts is True and self .get_box_prompts is False : # we want only points
91+ return coord_list , label_list , None , object_mask
92+
93+ elif self .get_point_prompts is False and self .get_box_prompts is True : # we want only boxes
94+ return None , None , bbox_list , object_mask
0 commit comments