Skip to content

Commit e56d82b

Browse files
committed
Merge branch 'box-prompter' of https://github.com/computational-cell-analytics/micro-sam into box-prompter
2 parents 5ba3d58 + cfc93d9 commit e56d82b

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

micro_sam/prompt_generators.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44

55
class PointAndBoxPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength, get_box_prompts=False):
6+
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7+
get_point_prompts=False, get_box_prompts=False):
78
self.n_positive_points = n_positive_points
89
self.n_negative_points = n_negative_points
910
self.dilation_strength = dilation_strength
1011
self.get_box_prompts = get_box_prompts
12+
self.get_point_prompts = get_point_prompts
1113

1214
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1315
"""
@@ -79,7 +81,14 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7981
label_list.append(0)
8082

8183
# returns object-level masks per instance for cross-verification (fix it later)
82-
if self.get_box_prompts:
84+
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box
8385
return coord_list, label_list, bbox_list, object_mask
84-
else:
86+
87+
elif self.get_point_prompts is True and self.get_box_prompts is False: # we want only points
8588
return coord_list, label_list, None, object_mask
89+
90+
elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes
91+
return None, None, bbox_list, object_mask
92+
else:
93+
assert self.get_point_prompts is False and self.get_box_prompts is False, \
94+
"You need to request for box/point prompts or both"

0 commit comments

Comments
 (0)