|
3 | 3 |
|
4 | 4 |
|
5 | 5 | class PointAndBoxPromptGenerator: |
6 | | - def __init__(self, n_positive_points, n_negative_points, dilation_strength, get_box_prompts=False): |
| 6 | + def __init__(self, n_positive_points, n_negative_points, dilation_strength, |
| 7 | + get_point_prompts=False, get_box_prompts=False): |
7 | 8 | self.n_positive_points = n_positive_points |
8 | 9 | self.n_negative_points = n_negative_points |
9 | 10 | self.dilation_strength = dilation_strength |
10 | 11 | self.get_box_prompts = get_box_prompts |
| 12 | + self.get_point_prompts = get_point_prompts |
11 | 13 |
|
12 | 14 | def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates): |
13 | 15 | """ |
@@ -77,7 +79,14 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates): |
77 | 79 | label_list.append(0) |
78 | 80 |
|
79 | 81 | # returns object-level masks per instance for cross-verification (fix it later) |
80 | | - if self.get_box_prompts: |
| 82 | + if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box |
81 | 83 | return coord_list, label_list, bbox_list, object_mask |
82 | | - else: |
| 84 | + |
| 85 | + elif self.get_point_prompts is True and self.get_box_prompts is False: # we want only points |
83 | 86 | return coord_list, label_list, None, object_mask |
| 87 | + |
| 88 | + elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes |
| 89 | + return None, None, bbox_list, object_mask |
| 90 | + else: |
| 91 | + assert self.get_point_prompts is False and self.get_box_prompts is False, \ |
| 92 | + "You need to request for box/point prompts or both" |
0 commit comments