Skip to content

Commit 3abcbd0

Browse files
committed
Update Prompted Generator - Add Box Prompts
1 parent ea84438 commit 3abcbd0

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

micro_sam/prompt_generators.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from scipy.ndimage import binary_dilation
33

44

5-
class PointPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength):
5+
class PointAndBoxPromptGenerator:
6+
def __init__(self, n_positive_points, n_negative_points, dilation_strength, get_box_prompts=False):
77
self.n_positive_points = n_positive_points
88
self.n_negative_points = n_negative_points
99
self.dilation_strength = dilation_strength
10+
self.get_box_prompts = get_box_prompts
1011

1112
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1213
"""
@@ -20,9 +21,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2021
label_list = []
2122

2223
# getting the center coordinate as the first positive point
23-
coord_list.append(tuple(map(int, center_coordinates)))
24+
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
2425
label_list.append(1)
2526

27+
if self.get_box_prompts:
28+
bbox_list = [bbox_coordinates]
29+
2630
object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
2731

2832
# getting the additional positive points by randomly sampling points from this mask
@@ -72,4 +76,8 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7276
coord_list.append(negative_coordinates)
7377
label_list.append(0)
7478

75-
return coord_list, label_list, None, object_mask
79+
# returns object-level masks per instance for cross-verification (fix it later)
80+
if self.get_box_prompts:
81+
return coord_list, label_list, bbox_list, object_mask
82+
else:
83+
return coord_list, label_list, None, object_mask

0 commit comments

Comments
 (0)