@@ -14,7 +14,7 @@ def __init__(self, n_positive_points, n_negative_points, dilation_strength,
1414 if self .get_point_prompts is False and self .get_box_prompts is False :
1515 raise ValueError ("You need to request for box/point prompts or both" )
1616
17- def __call__ (self , gt , gt_id , center_coordinates , bbox_coordinates ):
17+ def __call__ (self , gt , gt_id , bbox_coordinates , center_coordinates = None ):
1818 """
1919 Parameters:
2020 gt: True Labels
@@ -25,17 +25,23 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2525 coord_list = []
2626 label_list = []
2727
28- # getting the center coordinate as the first positive point
29- coord_list .append (tuple (map (int , center_coordinates ))) # to get int coords instead of float
30- label_list .append (1 )
28+ if center_coordinates is not None :
29+ # getting the center coordinate as the first positive point (OPTIONAL)
30+ coord_list .append (tuple (map (int , center_coordinates ))) # to get int coords instead of float
31+ label_list .append (1 )
32+
33+ # getting the additional positive points by randomly sampling points from this mask except the center coordinate
34+ n_positive_remaining = self .n_positive_points - 1
35+
36+ else :
37+ # need to sample "self.n_positive_points" number of points
38+ n_positive_remaining = self .n_positive_points
3139
3240 if self .get_box_prompts :
3341 bbox_list = [bbox_coordinates ]
3442
3543 object_mask = gt == gt_id
3644
37- # getting the additional positive points by randomly sampling points from this mask
38- n_positive_remaining = self .n_positive_points - 1
3945 if n_positive_remaining > 0 :
4046 # all coordinates of our current object
4147 object_coordinates = np .where (object_mask )
0 commit comments