Skip to content

Commit 3395ec7

Browse files
Add instance segmentation tests WIP
1 parent a0cd794 commit 3395ec7

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

test/test_segment_instances.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import unittest
2+
3+
import micro_sam.util as util
4+
import numpy as np
5+
6+
from elf.evaluation.matching import matching
7+
from skimage.draw import disk
8+
9+
10+
class TestSegmentInstances(unittest.TestCase):
11+
12+
# create an input image with three objects
13+
def _get_input(self, shape=(128, 128)):
14+
mask = np.zeros(shape, dtype="uint8")
15+
16+
def write_object(center, radius):
17+
circle = disk(center, radius, shape=shape)
18+
mask[circle] = 1
19+
20+
center = tuple(sh // 4 for sh in shape)
21+
write_object(center, radius=10)
22+
23+
center = tuple(sh // 2 for sh in shape)
24+
write_object(center, radius=9)
25+
26+
center = tuple(3 * sh // 4 for sh in shape)
27+
write_object(center, radius=11)
28+
29+
image = mask * 255
30+
return mask, image
31+
32+
def _get_model(self):
33+
predictor, sam = util.get_sam_model(model_type="vit_b", return_sam=True)
34+
return predictor, sam
35+
36+
@unittest.skip("This test takes very long.")
37+
def test_segment_instances_sam(self):
38+
from micro_sam.segment_instances import segment_instances_sam
39+
40+
mask, image = self._get_input()
41+
_, sam = self._get_model()
42+
43+
predicted = segment_instances_sam(sam, image)
44+
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
45+
46+
@unittest.skip("Needs some more debugging.")
47+
def test_segment_instances_from_embeddings(self):
48+
from micro_sam.segment_instances import segment_instances_from_embeddings
49+
50+
mask, image = self._get_input()
51+
predictor, _ = self._get_model()
52+
53+
image_embeddings = util.precompute_image_embeddings(predictor, image)
54+
util.set_precomputed(predictor, image_embeddings)
55+
56+
predicted = segment_instances_from_embeddings(predictor, image_embeddings)
57+
# import napari
58+
# v = napari.Viewer()
59+
# v.add_image(image)
60+
# v.add_labels(mask)
61+
# v.add_labels(predicted)
62+
# napari.run()
63+
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

0 commit comments

Comments
 (0)