Skip to content

Commit cfc93d9

Browse files
committed
Update Arguments for Multi-Prompting - for Box/Prompt(/both)
1 parent 3abcbd0 commit cfc93d9

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

micro_sam/prompt_generators.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44

55
class PointAndBoxPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength, get_box_prompts=False):
6+
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7+
get_point_prompts=False, get_box_prompts=False):
78
self.n_positive_points = n_positive_points
89
self.n_negative_points = n_negative_points
910
self.dilation_strength = dilation_strength
1011
self.get_box_prompts = get_box_prompts
12+
self.get_point_prompts = get_point_prompts
1113

1214
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1315
"""
@@ -77,7 +79,14 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7779
label_list.append(0)
7880

7981
# returns object-level masks per instance for cross-verification (fix it later)
80-
if self.get_box_prompts:
82+
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box
8183
return coord_list, label_list, bbox_list, object_mask
82-
else:
84+
85+
elif self.get_point_prompts is True and self.get_box_prompts is False: # we want only points
8386
return coord_list, label_list, None, object_mask
87+
88+
elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes
89+
return None, None, bbox_list, object_mask
90+
else:
91+
assert self.get_point_prompts is False and self.get_box_prompts is False, \
92+
"You need to request for box/point prompts or both"

0 commit comments

Comments
 (0)