Skip to content

Commit 6eed43f

Browse files
Merge pull request #2 from computational-cell-analytics/embedding-instance-seg
Add initial version of embedding based instance segmentation
2 parents 9b3c4da + e44ce82 commit 6eed43f

File tree

6 files changed

+223
-20
lines changed

6 files changed

+223
-20
lines changed

examples/instance_segmentation.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import micro_sam.util as util
2+
import napari
3+
4+
from elf.io import open_file
5+
from micro_sam.segment_instances import segment_from_embeddings
6+
from micro_sam.visualization import compute_pca
7+
8+
9+
def mito_segmentation():
10+
input_path = "./data/Lucchi++/Test_In"
11+
with open_file(input_path) as f:
12+
raw = f["*.png"][-1, :768, :768]
13+
14+
predictor = util.get_sam_model()
15+
image_embeddings = util.precompute_image_embeddings(predictor, raw, "./embeddings/embeddings-mito2d.zarr")
16+
embedding_pca = compute_pca(image_embeddings["features"])
17+
18+
seg, initial_seg = segment_from_embeddings(predictor, image_embeddings=image_embeddings, return_initial_seg=True)
19+
20+
v = napari.Viewer()
21+
v.add_image(raw)
22+
v.add_image(embedding_pca, scale=(12, 12))
23+
v.add_labels(seg)
24+
v.add_labels(initial_seg)
25+
napari.run()
26+
27+
28+
def cell_segmentation():
29+
path = "./DIC-C2DH-HeLa/train/01"
30+
with open_file(path, mode="r") as f:
31+
timeseries = f["*.tif"][:50]
32+
33+
frame = 11
34+
35+
predictor = util.get_sam_model()
36+
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, "./embeddings/embeddings-ctc.zarr")
37+
embedding_pca = compute_pca(image_embeddings["features"][frame])
38+
39+
seg, initial_seg = segment_from_embeddings(
40+
predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True
41+
)
42+
43+
v = napari.Viewer()
44+
v.add_image(timeseries[frame])
45+
v.add_image(embedding_pca, scale=(8, 8))
46+
v.add_labels(seg)
47+
v.add_labels(initial_seg)
48+
napari.run()
49+
50+
51+
def main():
52+
# automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py')
53+
# mito_segmentation()
54+
55+
# automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py')
56+
cell_segmentation()
57+
58+
59+
if __name__ == "__main__":
60+
main()

examples/sam_annotator_2d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@
22
from micro_sam.sam_annotator import annotator_2d
33

44

5-
def main():
5+
# TODO describe how to get the data and don't use hard-coded system path
6+
def livecell_annotator():
67
im = imageio.imread(
78
"/home/pape/Work/data/incu_cyte/livecell/images/livecell_test_images/A172_Phase_C7_1_01d04h00m_4.tif"
89
)
910
embedding_path = "./embeddings/embeddings-livecell_cropped.zarr"
10-
annotator_2d(im, embedding_path, show_embeddings=False)
11+
annotator_2d(im, embedding_path, show_embeddings=True)
12+
13+
14+
def main():
15+
# 2d annotator for livecell data
16+
# livecell_annotator()
17+
18+
# 2d annotator for cell tracking challenge hela data
19+
hela_2d_annotator()
1120

1221

1322
if __name__ == "__main__":

examples/sam_annotator_tracking.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,40 @@
11
from glob import glob
22

3-
import h5py
43
import numpy as np
4+
from elf.io import open_file
55
from micro_sam.sam_annotator import annotator_tracking
66

77

8-
def main():
8+
def track_incucyte_data():
99
pattern = "/home/pape/Work/data/incu_cyte/carmello/videos/MiaPaCa_flat_B3-3_registered/image-*"
1010
paths = glob(pattern)
1111
paths.sort()
1212

1313
timeseries = []
1414
for p in paths[:45]:
15-
with h5py.File(p) as f:
15+
with open_file(p, mode="r") as f:
1616
timeseries.append(f["phase-contrast"][:])
1717
timeseries = np.stack(timeseries)
1818

1919
annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-tracking.zarr", show_embeddings=False)
2020

2121

22+
# TODO describe how to get the data from CTC
23+
def track_ctc_data():
24+
path = "./data/DIC-C2DH-HeLa/train/01"
25+
with open_file(path, mode="r") as f:
26+
timeseries = f["*.tif"][:50]
27+
28+
annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-ctc.zarr")
29+
30+
31+
def main():
32+
# private data used for initial tests
33+
# track_incucyte_data()
34+
35+
# data from the cell tracking challenges
36+
track_ctc_data()
37+
38+
2239
if __name__ == "__main__":
2340
main()

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .. import util
88
from ..visualization import project_embeddings_for_visualization
9+
from ..segment_instances import segment_from_embeddings
910
from ..segment_from_prompts import segment_from_points
1011
from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points
1112

@@ -20,13 +21,22 @@ def segment_wigdet(v: Viewer):
2021
v.layers["current_object"].refresh()
2122

2223

24+
# TODO enable choosing setting the segmentation method and setting other params
25+
@magicgui(call_button="Segment All Objects")
26+
def autosegment_widget(v: Viewer):
27+
# choose if we segment with/without tiling based on the image shape
28+
seg = segment_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
29+
v.layers["auto_segmentation"].data = seg
30+
v.layers["auto_segmentation"].refresh()
31+
32+
2333
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
2434
# for access to the predictor and the image embeddings in the widgets
25-
global PREDICTOR
35+
global PREDICTOR, IMAGE_EMBEDDINGS
2636

2737
PREDICTOR = util.get_sam_model()
28-
image_embeddings = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
29-
util.set_precomputed(PREDICTOR, image_embeddings)
38+
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
39+
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
3040

3141
#
3242
# initialize the viewer and add layers
@@ -35,6 +45,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
3545
v = Viewer()
3646

3747
v.add_image(raw)
48+
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="auto_segmentation")
3849
if segmentation_result is None:
3950
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
4051
else:
@@ -43,7 +54,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
4354

4455
# show the PCA of the image embeddings
4556
if show_embeddings:
46-
embedding_vis, scale = project_embeddings_for_visualization(image_embeddings["features"], raw.shape)
57+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], raw.shape)
4758
v.add_image(embedding_vis, name="embeddings", scale=scale)
4859

4960
labels = ["positive", "negative"]
@@ -65,11 +76,12 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
6576
# add the widgets
6677
#
6778

68-
# TODO add (optional) auto-segmentation functionality
69-
7079
prompt_widget = create_prompt_menu(prompts, labels)
7180
v.window.add_dock_widget(prompt_widget)
7281

82+
# (optional) auto-segmentation functionality
83+
v.window.add_dock_widget(autosegment_widget)
84+
7385
v.window.add_dock_widget(segment_wigdet)
7486
v.window.add_dock_widget(commit_segmentation_widget)
7587

micro_sam/sam_annotator/util.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,23 @@
77
from ..segment_from_prompts import segment_from_points
88

99

10-
@magicgui(call_button="Commit [C]")
11-
def commit_segmentation_widget(v: Viewer):
12-
seg = v.layers["current_object"].data
10+
@magicgui(call_button="Commit [C]", layer={"choices": ["current_object", "auto_segmentation"]})
11+
def commit_segmentation_widget(v: Viewer, layer: str = "current_object"):
12+
seg = v.layers[layer].data
1313

14-
next_id = int(v.layers["committed_objects"].data.max() + 1)
15-
v.layers["committed_objects"].data[seg == 1] = next_id
14+
id_offset = int(v.layers["committed_objects"].data.max())
15+
mask = seg != 0
16+
17+
v.layers["committed_objects"].data[mask] = (seg[mask] + id_offset)
1618
v.layers["committed_objects"].refresh()
1719

1820
shape = v.layers["raw"].data.shape
19-
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
20-
v.layers["current_object"].refresh()
21+
v.layers[layer].data = np.zeros(shape, dtype="uint32")
22+
v.layers[layer].refresh()
2123

22-
v.layers["prompts"].data = []
23-
v.layers["prompts"].refresh()
24+
if layer == "current_object":
25+
v.layers["prompts"].data = []
26+
v.layers["prompts"].refresh()
2427

2528

2629
def create_prompt_menu(points_layer, labels):

micro_sam/segment_instances.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
import vigra
3+
4+
from elf.segmentation import embeddings as embed
5+
from skimage.transform import resize
6+
try:
7+
from napari.utils import progress as tqdm
8+
except ImportError:
9+
from tqdm import tqdm
10+
11+
from . import util
12+
from .segment_from_prompts import segment_from_mask
13+
14+
15+
#
16+
# Original SegmentAnything instance segmentation functionality
17+
#
18+
19+
20+
# TODO implement automatic instance segmentation based on the functionalities from segment anything:
21+
# https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
22+
23+
24+
#
25+
# Instance segmentation from embeddings
26+
#
27+
28+
29+
def _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, verbose):
30+
util.set_precomputed(predictor, image_embeddings, i)
31+
32+
original_size = image_embeddings["original_size"]
33+
seg = np.zeros(original_size, dtype="uint32")
34+
35+
seg_ids = np.unique(initial_seg)
36+
# TODO be smarter for overlapping masks, (use automatic_mask_generation from SAM as template)
37+
for seg_id in tqdm(seg_ids[1:], disable=not verbose, desc="Refine masks for automatic instance segmentation"):
38+
mask = (initial_seg == seg_id)
39+
assert mask.shape == (256, 256)
40+
refined = segment_from_mask(predictor, mask, original_size=original_size).squeeze()
41+
assert refined.shape == seg.shape
42+
seg[refined.squeeze()] = seg_id
43+
44+
# import napari
45+
# v = napari.Viewer()
46+
# v.add_image(mask)
47+
# v.add_labels(refined)
48+
# napari.run()
49+
50+
return seg
51+
52+
53+
# This is a first prototype for generating automatic instance segmentations from the image embeddings
54+
# predicted by the segment anything image encoder.
55+
56+
# Main challenge: the larger the image the worse this will get because of the fixed embedding size.
57+
# Ideas:
58+
# - Can we get intermediate, larger embeddings from SAM?
59+
# - Can we run the encoder in a sliding window and somehow stitch the embeddings?
60+
# - Or: run the encoder in a sliding window and stitch the initial segmentation result.
61+
def segment_from_embeddings(
62+
predictor, image_embeddings, size_threshold=10, i=None,
63+
offsets=[[-1, 0], [0, -1], [-3, 0], [0, -3]], distance_type="l2", bias=0.0,
64+
verbose=True, return_initial_seg=False,
65+
):
66+
util.set_precomputed(predictor, image_embeddings, i)
67+
68+
embeddings = predictor.get_image_embedding().squeeze().cpu().numpy()
69+
assert embeddings.shape == (256, 64, 64), f"{embeddings.shape}"
70+
initial_seg = embed.segment_embeddings_mws(
71+
embeddings, distance_type=distance_type, offsets=offsets, bias=bias
72+
).astype("uint32")
73+
assert initial_seg.shape == (64, 64), f"{initial_seg.shape}"
74+
75+
# filter out small objects
76+
seg_ids, sizes = np.unique(initial_seg, return_counts=True)
77+
initial_seg[np.isin(initial_seg, seg_ids[sizes < size_threshold])] = 0
78+
vigra.analysis.relabelConsecutive(initial_seg, out=initial_seg)
79+
80+
# resize to 256 x 256, which is the mask input expected by SAM
81+
initial_seg = resize(
82+
initial_seg, (256, 256), order=0, preserve_range=True, anti_aliasing=False
83+
).astype(initial_seg.dtype)
84+
seg = _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, verbose)
85+
86+
if return_initial_seg:
87+
initial_seg = resize(
88+
initial_seg, seg.shape, order=0, preserve_range=True, anti_aliasing=False
89+
).astype(seg.dtype)
90+
return seg, initial_seg
91+
else:
92+
return seg
93+
94+
95+
# TODO
96+
def segment_from_embeddings_with_tiling(
97+
predictor, image, image_embeddings, tile_shape=(256, 256), tile_overlap=(32, 32),
98+
size_threshold=10, i=None,
99+
offsets=[[-1, 0], [0, -1], [-3, 0], [0, -3]], distance_type="l2", bias=0.0,
100+
verbose=True, return_initial_seg=False,
101+
):
102+
pass

0 commit comments

Comments
 (0)