Skip to content

Commit 36079a4

Browse files
Merge pull request #9 from computational-cell-analytics/instance-seg
Improve automatic instance segmentation
2 parents 9a02fef + 3395ec7 commit 36079a4

File tree

6 files changed

+129
-22
lines changed

6 files changed

+129
-22
lines changed

development/instance_segmentation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import napari
33

44
from elf.io import open_file
5-
from micro_sam.segment_instances import segment_from_embeddings
5+
from micro_sam.segment_instances import segment_instances_from_embeddings, segment_instances_sam
66
from micro_sam.visualization import compute_pca
77

88

@@ -11,15 +11,23 @@ def mito_segmentation():
1111
with open_file(input_path) as f:
1212
raw = f["*.png"][-1, :768, :768]
1313

14-
predictor = util.get_sam_model()
14+
predictor, sam = util.get_sam_model(return_sam=True)
15+
16+
print("Run SAM prediction ...")
17+
seg_sam = segment_instances_sam(sam, raw)
18+
1519
image_embeddings = util.precompute_image_embeddings(predictor, raw, "../examples/embeddings/embeddings-mito2d.zarr")
1620
embedding_pca = compute_pca(image_embeddings["features"])
1721

18-
seg, initial_seg = segment_from_embeddings(predictor, image_embeddings=image_embeddings, return_initial_seg=True)
22+
print("Run prediction from embeddings ...")
23+
seg, initial_seg = segment_instances_from_embeddings(
24+
predictor, image_embeddings=image_embeddings, return_initial_seg=True
25+
)
1926

2027
v = napari.Viewer()
2128
v.add_image(raw)
2229
v.add_image(embedding_pca, scale=(12, 12))
30+
v.add_labels(seg_sam)
2331
v.add_labels(seg)
2432
v.add_labels(initial_seg)
2533
napari.run()
@@ -32,21 +40,27 @@ def cell_segmentation():
3240

3341
frame = 11
3442

35-
predictor = util.get_sam_model()
43+
predictor, sam = util.get_sam_model(return_sam=True)
44+
45+
print("Run prediction from embeddings ...")
3646
image_embeddings = util.precompute_image_embeddings(
3747
predictor, timeseries, "../examples/embeddings/embeddings-ctc.zarr"
3848
)
3949
embedding_pca = compute_pca(image_embeddings["features"][frame])
4050

41-
seg, initial_seg = segment_from_embeddings(
51+
seg, initial_seg = segment_instances_from_embeddings(
4252
predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True
4353
)
4454

55+
print("Run SAM prediction ...")
56+
seg_sam = segment_instances_sam(sam, timeseries[frame])
57+
4558
v = napari.Viewer()
4659
v.add_image(timeseries[frame])
4760
v.add_image(embedding_pca, scale=(8, 8))
4861
v.add_labels(seg)
4962
v.add_labels(initial_seg)
63+
v.add_labels(seg_sam)
5064
napari.run()
5165

5266

examples/sam_annotator_2d.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,30 @@
22
from micro_sam.sam_annotator import annotator_2d
33

44

5-
# TODO describe how to get the data and don't use hard-coded system path
5+
# TODO describe how to get the data and don't use hard-coded path
66
def livecell_annotator():
77
im = imageio.imread(
88
"/home/pape/Work/data/incu_cyte/livecell/images/livecell_test_images/A172_Phase_C7_1_01d04h00m_4.tif"
99
)
1010
embedding_path = "./embeddings/embeddings-livecell_cropped.zarr"
11-
annotator_2d(im, embedding_path, show_embeddings=True)
11+
annotator_2d(im, embedding_path, show_embeddings=False)
12+
13+
14+
# This runs interactive 2d annotation for data from the cell tracking challenge:
15+
# It uses the training data for the HeLA dataset. You can download the data from
16+
# http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip
17+
def hela_2d_annotator():
18+
im = imageio.imread("./data/DIC-C2DH-HeLa/train/01/t011.tif")
19+
embedding_path = "./embeddings/embeddings-hela2d.zarr"
20+
annotator_2d(im, embedding_path, show_embeddings=False)
1221

1322

1423
def main():
1524
# 2d annotator for livecell data
16-
livecell_annotator()
25+
# livecell_annotator()
1726

18-
# TODO
1927
# 2d annotator for cell tracking challenge hela data
20-
# hela_2d_annotator()
28+
hela_2d_annotator()
2129

2230

2331
if __name__ == "__main__":

examples/sam_annotator_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
# This runs the interactive tracking annotator for data from the cell tracking challenge:
6-
# It uses the training data for the HeLA dataset. You can download the data via
6+
# It uses the training data for the HeLA dataset. You can download the data from
77
# http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip
88
def track_ctc_data():
99
path = "./data/DIC-C2DH-HeLa/train/01"

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from napari import Viewer
66

77
from .. import util
8+
from .. import segment_instances
89
from ..visualization import project_embeddings_for_visualization
9-
from ..segment_instances import segment_from_embeddings
1010
from ..segment_from_prompts import segment_from_points
1111
from .util import (
1212
commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
@@ -21,20 +21,26 @@ def segment_wigdet(v: Viewer):
2121
v.layers["current_object"].refresh()
2222

2323

24-
# TODO enable choosing setting the segmentation method and setting other params
25-
@magicgui(call_button="Segment All Objects")
26-
def autosegment_widget(v: Viewer):
27-
# choose if we segment with/without tiling based on the image shape
28-
seg = segment_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
24+
# TODO expose more parameters
25+
@magicgui(call_button="Segment All Objects", method={"choices": ["default", "sam", "embeddings"]})
26+
def autosegment_widget(v: Viewer, method: str = "default"):
27+
if method in ("default", "sam"):
28+
print("Run automatic segmentation with SAM. This can take a few minutes ...")
29+
image = v.layers["raw"].data
30+
seg = segment_instances.segment_instances_sam(SAM, image)
31+
elif method == "embeddings":
32+
seg = segment_instances.segment_instances_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
33+
else:
34+
raise ValueError
2935
v.layers["auto_segmentation"].data = seg
3036
v.layers["auto_segmentation"].refresh()
3137

3238

3339
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
3440
# for access to the predictor and the image embeddings in the widgets
35-
global PREDICTOR, IMAGE_EMBEDDINGS
41+
global PREDICTOR, IMAGE_EMBEDDINGS, SAM
3642

37-
PREDICTOR = util.get_sam_model()
43+
PREDICTOR, SAM = util.get_sam_model(return_sam=True)
3844
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
3945
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
4046

micro_sam/segment_instances.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import vigra
33

44
from elf.segmentation import embeddings as embed
5+
from segment_anything import SamAutomaticMaskGenerator
56
from skimage.transform import resize
7+
68
try:
79
from napari.utils import progress as tqdm
810
except ImportError:
@@ -17,8 +19,18 @@
1719
#
1820

1921

20-
# TODO implement automatic instance segmentation based on the functionalities from segment anything:
21-
# https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
22+
def segment_instances_sam(sam, image, **kwargs):
23+
segmentor = SamAutomaticMaskGenerator(sam, **kwargs)
24+
25+
image_ = util._to_image(image)
26+
masks = segmentor.generate(image_)
27+
masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
28+
29+
segmentation = np.zeros(image.shape[:2], dtype="uint32")
30+
for seg_id, mask in enumerate(masks, 1):
31+
segmentation[mask["segmentation"]] = seg_id
32+
33+
return segmentation
2234

2335

2436
#
@@ -58,7 +70,7 @@ def _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, ve
5870
# - Can we get intermediate, larger embeddings from SAM?
5971
# - Can we run the encoder in a sliding window and somehow stitch the embeddings?
6072
# - Or: run the encoder in a sliding window and stitch the initial segmentation result.
61-
def segment_from_embeddings(
73+
def segment_instances_from_embeddings(
6274
predictor, image_embeddings, size_threshold=10, i=None,
6375
offsets=[[-1, 0], [0, -1], [-3, 0], [0, -3]], distance_type="l2", bias=0.0,
6476
verbose=True, return_initial_seg=False,

test/test_segment_instances.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import unittest
2+
3+
import micro_sam.util as util
4+
import numpy as np
5+
6+
from elf.evaluation.matching import matching
7+
from skimage.draw import disk
8+
9+
10+
class TestSegmentInstances(unittest.TestCase):
11+
12+
# create an input image with three objects
13+
def _get_input(self, shape=(128, 128)):
14+
mask = np.zeros(shape, dtype="uint8")
15+
16+
def write_object(center, radius):
17+
circle = disk(center, radius, shape=shape)
18+
mask[circle] = 1
19+
20+
center = tuple(sh // 4 for sh in shape)
21+
write_object(center, radius=10)
22+
23+
center = tuple(sh // 2 for sh in shape)
24+
write_object(center, radius=9)
25+
26+
center = tuple(3 * sh // 4 for sh in shape)
27+
write_object(center, radius=11)
28+
29+
image = mask * 255
30+
return mask, image
31+
32+
def _get_model(self):
33+
predictor, sam = util.get_sam_model(model_type="vit_b", return_sam=True)
34+
return predictor, sam
35+
36+
@unittest.skip("This test takes very long.")
37+
def test_segment_instances_sam(self):
38+
from micro_sam.segment_instances import segment_instances_sam
39+
40+
mask, image = self._get_input()
41+
_, sam = self._get_model()
42+
43+
predicted = segment_instances_sam(sam, image)
44+
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
45+
46+
@unittest.skip("Needs some more debugging.")
47+
def test_segment_instances_from_embeddings(self):
48+
from micro_sam.segment_instances import segment_instances_from_embeddings
49+
50+
mask, image = self._get_input()
51+
predictor, _ = self._get_model()
52+
53+
image_embeddings = util.precompute_image_embeddings(predictor, image)
54+
util.set_precomputed(predictor, image_embeddings)
55+
56+
predicted = segment_instances_from_embeddings(predictor, image_embeddings)
57+
# import napari
58+
# v = napari.Viewer()
59+
# v.add_image(image)
60+
# v.add_labels(mask)
61+
# v.add_labels(predicted)
62+
# napari.run()
63+
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

0 commit comments

Comments
 (0)