Skip to content

Commit 5e63ed4

Browse files
Merge pull request #58 from computational-cell-analytics/instance-seg-plus
Implement tool for interactive instacne segmentation
2 parents 9bb4594 + b75ab58 commit 5e63ed4

File tree

12 files changed

+325
-62
lines changed

12 files changed

+325
-62
lines changed

development/annotator_2d_tiled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def annotator_with_tiling():
1616
# napari.run()
1717

1818
embedding_path = "./embeddings/embeddings-tiled.zarr"
19-
annotator_2d(im, embedding_path, tile_shape=(1024, 1024), halo=(256, 256))
19+
annotator_2d(im, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), show_embeddings=False)
2020

2121

2222
def debug():

development/annotator_3d_tiled.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def annotator_with_tiling():
66
with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f:
77
raw = f["volumes/raw/s0"][:25]
88
embedding_path = "./embeddings/embeddings-tiled_3d.zarr"
9-
annotator_3d(raw, embedding_path, tile_shape=(512, 512), halo=(64, 64))
9+
annotator_3d(raw, embedding_path, tile_shape=(512, 512), halo=(64, 64), show_embeddings=False)
1010

1111

1212
def segment_tiled():
@@ -25,8 +25,8 @@ def segment_tiled():
2525

2626

2727
def main():
28-
# annotator_with_tiling()
29-
segment_tiled()
28+
annotator_with_tiling()
29+
# segment_tiled()
3030

3131

3232
main()

examples/sam_annotator_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def track_ctc_data():
1010
example_data_directory = "./data"
1111
with open_file(str(fetch_example_data(example_data_directory)), mode="r") as f:
1212
timeseries = f["*.tif"]
13-
annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-ctc.zarr")
13+
annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-ctc.zarr", show_embeddings=False)
1414

1515

1616
def fetch_example_data(save_directory):

micro_sam/sam_annotator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .annotator_3d import annotator_3d
33
from .annotator_tracking import annotator_tracking
44
from .image_series_annotator import image_folder_annotator, image_series_annotator
5+
from .interactive_instance_segmentation import interactive_instance_segmentation

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,34 @@ def segment_wigdet(v: Viewer):
3939
v.layers["current_object"].refresh()
4040

4141

42-
# TODO expose more parameters:
43-
# - min initial size
44-
# - advanced params???
4542
@magicgui(call_button="Automatic Segmentation")
4643
def autosegment_widget(
47-
v: Viewer, with_background: bool = True, box_extension: float = 0.1, pred_iou_thresh: float = 0.88
44+
v: Viewer,
45+
with_background: bool = True,
46+
pred_iou_thresh: float = 0.88,
47+
stability_score_thresh: float = 0.95,
48+
min_initial_size: int = 10,
49+
use_box: bool = True,
50+
use_mask: bool = True,
51+
use_points: bool = False,
52+
box_extension: float = 0.1,
4853
):
4954
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
5055
if is_tiled:
5156
seg = segment_instances.segment_instances_from_embeddings_with_tiling(
5257
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
5358
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
59+
stability_score_thresh=stability_score_thresh,
60+
min_initial_size=min_initial_size,
61+
use_box=use_box, use_points=use_points, use_mask=use_mask,
5462
)
5563
else:
5664
seg = segment_instances.segment_instances_from_embeddings(
5765
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
5866
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
67+
stability_score_thresh=stability_score_thresh,
68+
min_initial_size=min_initial_size,
69+
use_box=use_box, use_points=use_points, use_mask=use_mask,
5970
)
6071
v.layers["auto_segmentation"].data = seg
6172
v.layers["auto_segmentation"].refresh()
@@ -90,11 +101,9 @@ def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):
90101
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
91102

92103
# show the PCA of the image embeddings
93-
if show_embeddings and tile_shape is None:
94-
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], shape)
104+
if show_embeddings:
105+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
95106
v.add_image(embedding_vis, name="embeddings", scale=scale)
96-
elif show_embeddings:
97-
warnings.warn("Embeddings cannot be shown for tiled prediction.")
98107

99108
labels = ["positive", "negative"]
100109
prompts = v.add_points(
@@ -205,8 +214,6 @@ def annotator_2d(
205214

206215

207216
def main():
208-
import warnings
209-
210217
parser = _initialize_parser(description="Run interactive segmentation for an image.")
211218
args = parser.parse_args()
212219
raw = util.load_image_data(args.input, ndim=2, key=args.key)

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def annotator_3d(
215215

216216
# show the PCA of the image embeddings
217217
if show_embeddings:
218-
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], raw.shape)
218+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
219219
v.add_image(embedding_vis, name="embeddings", scale=scale)
220220

221221
labels = ["positive", "negative"]

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def annotator_tracking(
381381

382382
# show the PCA of the image embeddings
383383
if show_embeddings:
384-
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], raw.shape)
384+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
385385
v.add_image(embedding_vis, name="embeddings", scale=scale)
386386

387387
#
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import napari
2+
import numpy as np
3+
from napari import Viewer
4+
from magicgui import magicgui
5+
6+
from .annotator_2d import _get_shape
7+
from .util import _initialize_parser
8+
from ..import util
9+
from ..import segment_instances
10+
from ..visualization import project_embeddings_for_visualization
11+
12+
13+
@magicgui(call_button="Automatic Segmentation [S]")
14+
def autosegment_widget(
15+
v: Viewer,
16+
pred_iou_thresh: float = 0.88,
17+
stability_score_thresh: float = 0.95,
18+
min_initial_size: int = 10,
19+
box_extension: float = 0.1,
20+
with_background: bool = True,
21+
use_box: bool = True,
22+
use_mask: bool = True,
23+
use_points: bool = False,
24+
):
25+
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
26+
if is_tiled:
27+
seg, initial_seg = segment_instances.segment_instances_from_embeddings_with_tiling(
28+
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
29+
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
30+
stability_score_thresh=stability_score_thresh,
31+
min_initial_size=min_initial_size, return_initial_segmentation=True, verbose=2,
32+
use_box=use_box, use_mask=use_mask, use_points=use_points,
33+
)
34+
else:
35+
seg, initial_seg = segment_instances.segment_instances_from_embeddings(
36+
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
37+
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
38+
stability_score_thresh=stability_score_thresh,
39+
min_initial_size=min_initial_size, return_initial_segmentation=True, verbose=2,
40+
use_box=use_box, use_mask=use_mask, use_points=use_points,
41+
)
42+
43+
v.layers["auto_segmentation"].data = seg
44+
v.layers["auto_segmentation"].refresh()
45+
46+
v.layers["initial_segmentation"].data = initial_seg
47+
v.layers["initial_segmentation"].refresh()
48+
49+
50+
def interactive_instance_segmentation(
51+
raw, embedding_path=None, model_type="vit_h", tile_shape=None, halo=None, checkpoint=None,
52+
):
53+
"""Visualizing and debugging automatic instance segmentation.
54+
"""
55+
global PREDICTOR, IMAGE_EMBEDDINGS
56+
57+
PREDICTOR = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
58+
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
59+
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
60+
)
61+
62+
shape = _get_shape(raw)
63+
64+
v = napari.Viewer()
65+
66+
v.add_image(raw)
67+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
68+
v.add_image(embedding_vis, name="embeddings", scale=scale, visible=False)
69+
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="initial_segmentation")
70+
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
71+
72+
v.window.add_dock_widget(autosegment_widget)
73+
74+
napari.run()
75+
76+
77+
def main():
78+
parser = _initialize_parser(
79+
description="Run the automatic instance segmentation in interactive mode"
80+
"to determine the optimal segmentation parameters.",
81+
with_segmentation_result=False, with_show_embeddings=False,
82+
)
83+
parser.add_argument(
84+
"-c", "--checkpoint", default=None,
85+
help="Path to alternative checkpoint instead of the standard model."
86+
)
87+
args = parser.parse_args()
88+
raw = util.load_image_data(args.input, ndim=2, key=args.key)
89+
interactive_instance_segmentation(
90+
raw, args.embedding_path, args.model_type, args.tile_shape, args.halo, args.checkpoint
91+
)

micro_sam/sam_annotator/util.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def toggle_label(prompts):
323323
prompts.refresh_colors()
324324

325325

326-
def _initialize_parser(description, with_segmentation_result=True):
326+
def _initialize_parser(description, with_segmentation_result=True, with_show_embeddings=True):
327327
parser = argparse.ArgumentParser(description=description)
328328

329329
parser.add_argument(
@@ -336,6 +336,7 @@ def _initialize_parser(description, with_segmentation_result=True):
336336
help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
337337
"for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
338338
)
339+
339340
parser.add_argument(
340341
"-e", "--embedding_path",
341342
help="The filepath for saving/loading the pre-computed image embeddings. "
@@ -356,10 +357,11 @@ def _initialize_parser(description, with_segmentation_result=True):
356357
help="The key for opening the segmentation data. Same rules as for 'key' apply."
357358
)
358359

359-
parser.add_argument(
360-
"--show_embeddings", action="store_true",
361-
help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
362-
)
360+
if with_show_embeddings:
361+
parser.add_argument(
362+
"--show_embeddings", action="store_true",
363+
help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
364+
)
363365
parser.add_argument(
364366
"--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
365367
)

0 commit comments

Comments
 (0)