Skip to content

Commit 618cfb6

Browse files
Add tests for segment_from_prompts
1 parent 5c53672 commit 618cfb6

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

test/test_segment_from_prompts.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
3+
import micro_sam.util as util
4+
import numpy as np
5+
6+
from skimage.draw import disk
7+
8+
9+
class TestSegmentFromPrompts(unittest.TestCase):
10+
def _get_input(self):
11+
shape = (256, 256)
12+
mask = np.zeros(shape, dtype="uint8")
13+
circle = disk((128, 128), radius=20, shape=shape)
14+
mask[circle] = 1
15+
image = mask * 255
16+
return mask, image
17+
18+
def _get_model(self, image):
19+
predictor = util.get_sam_model(model_type="vit_b")
20+
image_embeddings = util.precompute_image_embeddings(predictor, image)
21+
util.set_precomputed(predictor, image_embeddings)
22+
return predictor
23+
24+
def test_segment_from_points(self):
25+
from micro_sam.segment_from_prompts import segment_from_points
26+
27+
mask, image = self._get_input()
28+
predictor = self._get_model(image)
29+
30+
points = np.array([[128, 128], [64, 64], [192, 192], [64, 192], [192, 64]])
31+
labels = np.array([1, 0, 0, 0, 0])
32+
33+
predicted = segment_from_points(predictor, points, labels)
34+
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
35+
36+
def test_segment_from_mask(self):
37+
from micro_sam.segment_from_prompts import segment_from_mask
38+
39+
mask, image = self._get_input()
40+
predictor = self._get_model(image)
41+
42+
# with mask and bounding box (default setting)
43+
predicted = segment_from_mask(predictor, mask)
44+
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
45+
46+
# with bounding box (default setting)
47+
predicted = segment_from_mask(predictor, mask, use_mask=False, use_box=True)
48+
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
49+
50+
# with bounding box (default setting)
51+
predicted = segment_from_mask(predictor, mask, use_mask=True, use_box=False)
52+
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)