# Example of a small application implemented with napari and the micro_sam library:
# iterate over a series of images in a folder and annotate them with SAM.
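# Key bindings: [S] segment the current object, [T] toggle the prompt label, [N] save and go to the next image.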

import os
from glob import glob

import imageio
import micro_sam.util as util
import napari
import numpy as np

from magicgui import magicgui
from micro_sam.segment_from_prompts import segment_from_points
from micro_sam.sam_annotator.util import create_prompt_menu, prompt_layer_to_points
from napari import Viewer

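# widget for interactive segmentation: reads the point prompts from the prompt layer
# and updates the segmentation layer with the mask predicted by SAM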
@magicgui(call_button="Segment Object [S]")
def segment_widget(v: Viewer):
    points, labels = prompt_layer_to_points(v.layers["prompts"])
    seg = segment_from_points(PREDICTOR, points, labels)
    v.layers["segmented_object"].data = seg.squeeze()
    v.layers["segmented_object"].refresh()


def image_series_annotator(image_paths, embedding_save_path, output_folder):
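    """Annotate a series of images with SAM point prompts and save each
    segmentation result as a tif file in output_folder.
    """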
    global PREDICTOR  # the predictor is global so that the segmentation widget can access it

    os.makedirs(output_folder, exist_ok=True)

    # get the sam predictor and precompute the image embeddings
    # (they are cached at embedding_save_path, so a second run can re-use them)
    PREDICTOR = util.get_sam_model()
    images = np.stack([imageio.imread(p) for p in image_paths])
    image_embeddings = util.precompute_image_embeddings(PREDICTOR, images, save_path=embedding_save_path)
    util.set_precomputed(PREDICTOR, image_embeddings, i=0)

    v = napari.Viewer()

    # add the first image; next_image_id is the index of the image that will be shown next
    v.add_image(images[0], name="image")
    next_image_id = 1

    # add a layer for the segmented object
    v.add_labels(data=np.zeros(images.shape[1:], dtype="uint32"), name="segmented_object")

    # create the point layer for the sam prompts and add the widget for toggling the points
    labels = ["positive", "negative"]
    prompts = v.add_points(
        data=[[0.0, 0.0], [0.0, 0.0]],  # FIXME workaround: dummy points so that both labels get a color
        name="prompts",
        properties={"label": labels},
        edge_color="label",
        edge_color_cycle=["green", "red"],
        symbol="o",
        face_color="transparent",
        edge_width=0.5,
        size=12,
        ndim=2,
    )
    prompts.data = []  # remove the dummy points again
    prompts.edge_color_mode = "cycle"
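    # menu for choosing whether newly added points are positive or negative prompts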
    prompt_widget = create_prompt_menu(prompts, labels)
    v.window.add_dock_widget(prompt_widget)

    # toggle the points between positive / negative
    @v.bind_key("t")
    def toggle_label(event=None):
        # get the currently selected label and flip it
        current_properties = prompts.current_properties
        current_label = current_properties["label"][0]
        new_label = "negative" if current_label == "positive" else "positive"
        current_properties["label"] = np.array([new_label])
        prompts.current_properties = current_properties
        prompts.refresh()
        prompts.refresh_colors()

    # bind the segmentation to the key 's'
    @v.bind_key("s")
    def _segment(v):
        segment_widget(v)

    #
    # the functionality for saving segmentations and going to the next image
    #

    def _save_segmentation(seg, output_folder, image_path):
        fname = os.path.basename(image_path)
        save_path = os.path.join(output_folder, os.path.splitext(fname)[0] + ".tif")
        imageio.imwrite(save_path, seg)

    def _next(v):
        # show the next image and load its embeddings; clear the segmentation and the prompts
        nonlocal next_image_id
        v.layers["image"].data = images[next_image_id]
        util.set_precomputed(PREDICTOR, image_embeddings, i=next_image_id)

        v.layers["segmented_object"].data = np.zeros(images.shape[1:], dtype="uint32")
        v.layers["prompts"].data = []

        next_image_id += 1
        if next_image_id >= images.shape[0]:
            print("Last image!")

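    # save the current segmentation and go to the next image (bound to [N])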
    @v.bind_key("n")
    def next_image(v):
        seg = v.layers["segmented_object"].data
        if seg.max() == 0:
            print("This image has not been segmented yet, doing nothing!")
            return

        _save_segmentation(seg, output_folder, image_paths[next_image_id - 1])
        if next_image_id < images.shape[0]:  # don't advance past the last image
            _next(v)

    napari.run()


# this uses data from the cell tracking challenge as example data;
# see the 'sam_annotator_tracking' example for more details on this data
def main():
    image_paths = sorted(glob("./data/DIC-C2DH-HeLa/train/01/*.tif"))[:50]
    image_series_annotator(image_paths, "./embeddings/embeddings-ctc.zarr", "segmented-series")


if __name__ == "__main__":
    main()