Skip to content

Commit 45119dd

Browse files
Implement segment_from_box_and_points and test
1 parent 4cc9d0d commit 45119dd

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

micro_sam/segment_from_prompts.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,22 @@ def segment_from_box(
130130
return mask, scores, logits
131131
else:
132132
return mask
133+
134+
135+
def segment_from_box_and_points(
136+
predictor, box, points, labels,
137+
image_embeddings=None, i=None, original_size=None,
138+
multimask_output=False, return_all=False,
139+
):
140+
if image_embeddings is not None:
141+
util.set_precomputed(predictor, image_embeddings, i)
142+
mask, scores, logits = predictor.predict(
143+
point_coords=points[:, ::-1], # SAM has reversed XY conventions
144+
point_labels=labels,
145+
box=_process_box(box, original_size),
146+
multimask_output=multimask_output
147+
)
148+
if return_all:
149+
return mask, scores, logits
150+
else:
151+
return mask

test/test_segment_from_prompts.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,47 @@
77

88

99
class TestSegmentFromPrompts(unittest.TestCase):
10-
def _get_input(self, shape=(256, 256)):
10+
@staticmethod
11+
def _get_input(shape=(256, 256)):
1112
mask = np.zeros(shape, dtype="uint8")
1213
center = tuple(sh // 2 for sh in shape)
1314
circle = disk(center, radius=20, shape=shape)
1415
mask[circle] = 1
1516
image = mask * 255
1617
return mask, image
1718

18-
def _get_model(self, image):
19+
@staticmethod
20+
def _get_model(image):
1921
predictor = util.get_sam_model(model_type="vit_b")
2022
image_embeddings = util.precompute_image_embeddings(predictor, image)
2123
util.set_precomputed(predictor, image_embeddings)
2224
return predictor
2325

26+
# we compute the default mask and predictor once for the class
27+
# so that we don't have to precompute it every time
28+
@classmethod
29+
def setUpClass(cls):
30+
cls.mask, cls.image = cls._get_input()
31+
cls.predictor = cls._get_model(cls.image)
32+
2433
def test_segment_from_points(self):
2534
from micro_sam.segment_from_prompts import segment_from_points
2635

27-
mask, image = self._get_input()
28-
predictor = self._get_model(image)
29-
3036
points = np.array([[128, 128], [64, 64], [192, 192], [64, 192], [192, 64]])
3137
labels = np.array([1, 0, 0, 0, 0])
3238

33-
predicted = segment_from_points(predictor, points, labels)
34-
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
39+
predicted = segment_from_points(self.predictor, points, labels)
40+
self.assertGreater(util.compute_iou(self.mask, predicted), 0.9)
3541

3642
def _test_segment_from_mask(self, shape=(256, 256), use_mask=True):
3743
from micro_sam.segment_from_prompts import segment_from_mask
3844

39-
mask, image = self._get_input(shape)
40-
predictor = self._get_model(image)
45+
if shape == (256, 256):
46+
mask, image = self.mask, self.image
47+
predictor = self.predictor
48+
else:
49+
mask, image = self._get_input(shape)
50+
predictor = self._get_model(image)
4151

4252
# with mask and bounding box (default setting)
4353
if use_mask:
@@ -64,12 +74,19 @@ def test_segment_from_mask_non_square(self):
6474
def test_segment_from_box(self):
6575
from micro_sam.segment_from_prompts import segment_from_box
6676

67-
mask, image = self._get_input()
68-
predictor = self._get_model(image)
77+
box = np.array([106, 106, 150, 150])
78+
predicted = segment_from_box(self.predictor, box)
79+
self.assertGreater(util.compute_iou(self.mask, predicted), 0.9)
80+
81+
def test_segment_from_box_and_points(self):
82+
from micro_sam.segment_from_prompts import segment_from_box_and_points
6983

7084
box = np.array([106, 106, 150, 150])
71-
predicted = segment_from_box(predictor, box)
72-
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
85+
points = np.array([[128, 128], [64, 64], [192, 192], [64, 192], [192, 64]])
86+
labels = np.array([1, 0, 0, 0, 0])
87+
88+
predicted = segment_from_box_and_points(self.predictor, box, points, labels)
89+
self.assertGreater(util.compute_iou(self.mask, predicted), 0.9)
7390

7491

7592
if __name__ == "__main__":

0 commit comments

Comments
 (0)