Skip to content

Commit 25ea229

Browse files
committed
Add ValueError for prompting (box/prompt) mechanism
1 parent e56d82b commit 25ea229

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

micro_sam/prompt_generators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44

55
class PointAndBoxPromptGenerator:
66
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7-
get_point_prompts=False, get_box_prompts=False):
7+
get_point_prompts=True, get_box_prompts=False):
88
self.n_positive_points = n_positive_points
99
self.n_negative_points = n_negative_points
1010
self.dilation_strength = dilation_strength
1111
self.get_box_prompts = get_box_prompts
1212
self.get_point_prompts = get_point_prompts
1313

14+
if self.get_point_prompts is False and self.get_box_prompts is False:
15+
raise ValueError("You need to request for box/point prompts or both")
16+
1417
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1518
"""
1619
Parameters:
@@ -80,7 +83,7 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
8083
coord_list.append(negative_coordinates)
8184
label_list.append(0)
8285

83-
# returns object-level masks per instance for cross-verification (fix it later)
86+
# returns object-level masks per instance for cross-verification (TODO: fix it later)
8487
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box
8588
return coord_list, label_list, bbox_list, object_mask
8689

@@ -89,6 +92,3 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
8992

9093
elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes
9194
return None, None, bbox_list, object_mask
92-
else:
93-
assert self.get_point_prompts is False and self.get_box_prompts is False, \
94-
"You need to request for box/point prompts or both"

0 commit comments

Comments
 (0)