Skip to content

Commit e220a84

Browse files
Merge pull request #13 from computational-cell-analytics/box-prompter
Update Prompted Generator - Add Box Prompts
2 parents e15251b + 25ea229 commit e220a84

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

micro_sam/prompt_generators.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
from scipy.ndimage import binary_dilation
33

44

5-
class PointPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength):
5+
class PointAndBoxPromptGenerator:
6+
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7+
get_point_prompts=True, get_box_prompts=False):
78
self.n_positive_points = n_positive_points
89
self.n_negative_points = n_negative_points
910
self.dilation_strength = dilation_strength
11+
self.get_box_prompts = get_box_prompts
12+
self.get_point_prompts = get_point_prompts
13+
14+
if self.get_point_prompts is False and self.get_box_prompts is False:
15+
raise ValueError("You need to request for box/point prompts or both")
1016

1117
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1218
"""
@@ -20,9 +26,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2026
label_list = []
2127

2228
# getting the center coordinate as the first positive point
23-
coord_list.append(tuple(map(int, center_coordinates)))
29+
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
2430
label_list.append(1)
2531

32+
if self.get_box_prompts:
33+
bbox_list = [bbox_coordinates]
34+
2635
object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
2736

2837
# getting the additional positive points by randomly sampling points from this mask
@@ -52,8 +61,10 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
5261
dilated_object = binary_dilation(object_mask, iterations=self.dilation_strength)
5362
background_mask = np.zeros(gt.shape)
5463
background_mask[bbox_coordinates[0]:bbox_coordinates[2], bbox_coordinates[1]:bbox_coordinates[3]] = 1
55-
background_mask = abs(background_mask - dilated_object)
5664
background_mask = binary_dilation(background_mask, iterations=self.dilation_strength)
65+
background_mask = abs(
66+
background_mask.astype(np.float32) - dilated_object.astype(np.float32)
67+
) # casting booleans to do subtraction
5768

5869
n_negative_remaining = self.n_negative_points
5970
if n_negative_remaining > 0:
@@ -72,4 +83,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7283
coord_list.append(negative_coordinates)
7384
label_list.append(0)
7485

75-
return coord_list, label_list, None, object_mask
86+
# returns object-level masks per instance for cross-verification (TODO: fix it later)
87+
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box
88+
return coord_list, label_list, bbox_list, object_mask
89+
90+
elif self.get_point_prompts is True and self.get_box_prompts is False: # we want only points
91+
return coord_list, label_list, None, object_mask
92+
93+
elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes
94+
return None, None, bbox_list, object_mask

0 commit comments

Comments
 (0)