Skip to content

Commit f351d72

Browse files
Merge branch 'master' into test-prompts
2 parents 801a330 + e220a84 commit f351d72

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

micro_sam/prompt_generators.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
from scipy.ndimage import binary_dilation
33

44

5-
class PointPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength):
5+
class PointAndBoxPromptGenerator:
6+
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7+
get_point_prompts=True, get_box_prompts=False):
78
self.n_positive_points = n_positive_points
89
self.n_negative_points = n_negative_points
910
self.dilation_strength = dilation_strength
11+
self.get_box_prompts = get_box_prompts
12+
self.get_point_prompts = get_point_prompts
13+
14+
if self.get_point_prompts is False and self.get_box_prompts is False:
15+
raise ValueError("You need to request for box/point prompts or both")
1016

1117
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
1218
"""
@@ -20,9 +26,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2026
label_list = []
2127

2228
# getting the center coordinate as the first positive point
23-
coord_list.append(tuple(map(int, center_coordinates)))
29+
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
2430
label_list.append(1)
2531

32+
if self.get_box_prompts:
33+
bbox_list = [bbox_coordinates]
34+
2635
object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
2736

2837
# getting the additional positive points by randomly sampling points from this mask
@@ -74,4 +83,12 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7483
coord_list.append(negative_coordinates)
7584
label_list.append(0)
7685

77-
return coord_list, label_list, None, object_mask
86+
# returns object-level masks per instance for cross-verification (TODO: fix it later)
87+
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box
88+
return coord_list, label_list, bbox_list, object_mask
89+
90+
elif self.get_point_prompts is True and self.get_box_prompts is False: # we want only points
91+
return coord_list, label_list, None, object_mask
92+
93+
elif self.get_point_prompts is False and self.get_box_prompts is True: # we want only boxes
94+
return None, None, bbox_list, object_mask

micro_sam/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_
105105

106106

107107
def _to_image(input_):
108+
# we require the input to be uint8
109+
if input_.dtype != np.dtype("uint8"):
110+
# first normalize the input to [0, 1]
111+
input_ = input_.astype("float32") - input_.min()
112+
input_ = input_ / input_.max()
113+
# then bring to [0, 255] and cast to uint8
114+
input_ = (input_ * 255).astype("uint8")
108115
if input_.ndim == 2:
109116
image = np.concatenate([input_[..., None]] * 3, axis=-1)
110117
elif input_.ndim == 3 and input_.shape[-1] == 3:

0 commit comments

Comments
 (0)