22from scipy .ndimage import binary_dilation
33
44
5- class PointPromptGenerator :
6- def __init__ (self , n_positive_points , n_negative_points , dilation_strength ):
5+ class PointAndBoxPromptGenerator :
6+ def __init__ (self , n_positive_points , n_negative_points , dilation_strength , get_box_prompts = False ):
77 self .n_positive_points = n_positive_points
88 self .n_negative_points = n_negative_points
99 self .dilation_strength = dilation_strength
10+ self .get_box_prompts = get_box_prompts
1011
1112 def __call__ (self , gt , gt_id , center_coordinates , bbox_coordinates ):
1213 """
@@ -20,9 +21,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2021 label_list = []
2122
2223 # getting the center coordinate as the first positive point
23- coord_list .append (tuple (map (int , center_coordinates )))
24+ coord_list .append (tuple (map (int , center_coordinates ))) # to get int coords instead of float
2425 label_list .append (1 )
2526
27+ if self .get_box_prompts :
28+ bbox_list = [bbox_coordinates ]
29+
2630 object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
2731
2832 # getting the additional positive points by randomly sampling points from this mask
@@ -72,4 +76,8 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7276 coord_list .append (negative_coordinates )
7377 label_list .append (0 )
7478
75- return coord_list , label_list , None , object_mask
79+ # returns object-level masks per instance for cross-verification (fix it later)
80+ if self .get_box_prompts :
81+ return coord_list , label_list , bbox_list , object_mask
82+ else :
83+ return coord_list , label_list , None , object_mask
0 commit comments