44
55class PointAndBoxPromptGenerator :
66 def __init__ (self , n_positive_points , n_negative_points , dilation_strength ,
7- get_point_prompts = False , get_box_prompts = False ):
7+ get_point_prompts = True , get_box_prompts = False ):
88 self .n_positive_points = n_positive_points
99 self .n_negative_points = n_negative_points
1010 self .dilation_strength = dilation_strength
1111 self .get_box_prompts = get_box_prompts
1212 self .get_point_prompts = get_point_prompts
1313
14+ if self .get_point_prompts is False and self .get_box_prompts is False :
15+ raise ValueError ("You need to request for box/point prompts or both" )
16+
1417 def __call__ (self , gt , gt_id , center_coordinates , bbox_coordinates ):
1518 """
1619 Parameters:
@@ -80,7 +83,7 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
8083 coord_list .append (negative_coordinates )
8184 label_list .append (0 )
8285
83- # returns object-level masks per instance for cross-verification (fix it later)
86+ # returns object-level masks per instance for cross-verification (TODO: fix it later)
8487 if self .get_point_prompts is True and self .get_box_prompts is True : # we want points and box
8588 return coord_list , label_list , bbox_list , object_mask
8689
@@ -89,6 +92,3 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
8992
9093 elif self .get_point_prompts is False and self .get_box_prompts is True : # we want only boxes
9194 return None , None , bbox_list , object_mask
92- else :
93- assert self .get_point_prompts is False and self .get_box_prompts is False , \
94- "You need to request for box/point prompts or both"
0 commit comments