Skip to content

Commit b4c316c

Browse files
Implement image series annotator as example app
1 parent 6eed43f commit b4c316c

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Example for a small application implemented using napari and the micro_sam library:
2+
# Iterate over a series of images in a folder and provide annotations with SAM.
3+
4+
import os
5+
from glob import glob
6+
7+
import imageio
8+
import micro_sam.util as util
9+
import napari
10+
import numpy as np
11+
12+
from magicgui import magicgui
13+
from micro_sam.segment_from_prompts import segment_from_points
14+
from micro_sam.sam_annotator.util import create_prompt_menu, prompt_layer_to_points
15+
from napari import Viewer
16+
17+
18+
@magicgui(call_button="Segment Object [S]")
19+
def segment_wigdet(v: Viewer):
20+
points, labels = prompt_layer_to_points(v.layers["prompts"])
21+
seg = segment_from_points(PREDICTOR, points, labels)
22+
v.layers["segmented_object"].data = seg.squeeze()
23+
v.layers["segmented_object"].refresh()
24+
25+
26+
def image_series_annotator(image_paths, embedding_save_path, output_folder):
27+
global PREDICTOR
28+
29+
os.makedirs(output_folder, exist_ok=True)
30+
31+
# get the sam predictor and precompute the image embeddings
32+
PREDICTOR = util.get_sam_model()
33+
images = np.stack([imageio.imread(p) for p in image_paths])
34+
image_embeddings = util.precompute_image_embeddings(PREDICTOR, images, save_path=embedding_save_path)
35+
util.set_precomputed(PREDICTOR, image_embeddings, i=0)
36+
37+
v = napari.Viewer()
38+
39+
# add the first image
40+
next_image_id = 0
41+
v.add_image(images[0], name="image")
42+
43+
# add a layer for the segmented object
44+
v.add_labels(data=np.zeros(images.shape[1:], dtype="uint32"), name="segmented_object")
45+
46+
# create the point layer for the sam prompts and add the widget for toggling the points
47+
labels = ["positive", "negative"]
48+
prompts = v.add_points(
49+
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround
50+
name="prompts",
51+
properties={"label": labels},
52+
edge_color="label",
53+
edge_color_cycle=["green", "red"],
54+
symbol="o",
55+
face_color="transparent",
56+
edge_width=0.5,
57+
size=12,
58+
ndim=2,
59+
)
60+
prompts.data = []
61+
prompts.edge_color_mode = "cycle"
62+
prompt_widget = create_prompt_menu(prompts, labels)
63+
v.window.add_dock_widget(prompt_widget)
64+
65+
# toggle the points between positive / negative
66+
@v.bind_key("t")
67+
def toggle_label(event=None):
68+
# get the currently selected label
69+
current_properties = prompts.current_properties
70+
current_label = current_properties["label"][0]
71+
new_label = "negative" if current_label == "positive" else "positive"
72+
current_properties["label"] = np.array([new_label])
73+
prompts.current_properties = current_properties
74+
prompts.refresh()
75+
prompts.refresh_colors()
76+
77+
# bind the segmentation to a key 's'
78+
@v.bind_key("s")
79+
def _segmet(v):
80+
segment_wigdet(v)
81+
82+
#
83+
# the functionality for saving segmentations and going to the next image
84+
#
85+
86+
def _save_segmentation(seg, output_folder, image_path):
87+
fname = os.path.basename(image_path)
88+
save_path = os.path.join(output_folder, os.path.splitext(fname)[0] + ".tif")
89+
imageio.imwrite(save_path, seg)
90+
91+
def _next(v):
92+
nonlocal next_image_id
93+
v.layers["image"].data = images[next_image_id]
94+
util.set_precomputed(PREDICTOR, image_embeddings, i=next_image_id)
95+
96+
v.layers["segmented_object"].data = np.zeros(images[0].shape, dtype="uint32")
97+
v.layers["prompts"].data = []
98+
99+
next_image_id += 1
100+
if next_image_id >= images.shape[0]:
101+
print("Last image!")
102+
103+
@v.bind_key("n")
104+
def next_image(v):
105+
seg = v.layers["segmented_object"].data
106+
if seg.max() == 0:
107+
print("This image has not been segmented yet, doing nothing!")
108+
return
109+
110+
_save_segmentation(seg, output_folder, image_paths[next_image_id - 1])
111+
_next(v)
112+
113+
napari.run()
114+
115+
116+
# this uses data from the cell tracking challenge as example data
117+
# see 'sam_annotator_tracking' for examples
118+
def main():
119+
image_paths = sorted(glob("./data/DIC-C2DH-HeLa/train/01/*.tif"))[:50]
120+
image_series_annotator(image_paths, "./embeddings/embeddings-ctc.zarr", "segmented-series")
121+
122+
123+
if __name__ == "__main__":
124+
main()

0 commit comments

Comments
 (0)