Skip to content

Commit a0cd794

Browse files
Update 2d annotator to use automatic SAM instance segmentation by default
1 parent 5922499 commit a0cd794

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

examples/sam_annotator_2d.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,30 @@
22
from micro_sam.sam_annotator import annotator_2d
33

44

5-
# TODO describe how to get the data and don't use hard-coded system path
5+
# TODO describe how to get the data and don't use hard-coded path
66
def livecell_annotator():
77
im = imageio.imread(
88
"/home/pape/Work/data/incu_cyte/livecell/images/livecell_test_images/A172_Phase_C7_1_01d04h00m_4.tif"
99
)
1010
embedding_path = "./embeddings/embeddings-livecell_cropped.zarr"
11-
annotator_2d(im, embedding_path, show_embeddings=True)
11+
annotator_2d(im, embedding_path, show_embeddings=False)
12+
13+
14+
# This runs interactive 2d annotation for data from the cell tracking challenge:
15+
# It uses the training data for the HeLA dataset. You can download the data from
16+
# http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip
17+
def hela_2d_annotator():
18+
im = imageio.imread("./data/DIC-C2DH-HeLa/train/01/t011.tif")
19+
embedding_path = "./embeddings/embeddings-hela2d.zarr"
20+
annotator_2d(im, embedding_path, show_embeddings=False)
1221

1322

1423
def main():
1524
# 2d annotator for livecell data
16-
livecell_annotator()
25+
# livecell_annotator()
1726

18-
# TODO
1927
# 2d annotator for cell tracking challenge hela data
20-
# hela_2d_annotator()
28+
hela_2d_annotator()
2129

2230

2331
if __name__ == "__main__":

examples/sam_annotator_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
# This runs the interactive tracking annotator for data from the cell tracking challenge:
6-
# It uses the training data for the HeLA dataset. You can download the data via
6+
# It uses the training data for the HeLA dataset. You can download the data from
77
# http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip
88
def track_ctc_data():
99
path = "./data/DIC-C2DH-HeLa/train/01"

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from napari import Viewer
66

77
from .. import util
8+
from .. import segment_instances
89
from ..visualization import project_embeddings_for_visualization
9-
from ..segment_instances import segment_from_embeddings
1010
from ..segment_from_prompts import segment_from_points
1111
from .util import (
1212
commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
@@ -21,20 +21,26 @@ def segment_wigdet(v: Viewer):
2121
v.layers["current_object"].refresh()
2222

2323

24-
# TODO enable choosing setting the segmentation method and setting other params
25-
@magicgui(call_button="Segment All Objects")
26-
def autosegment_widget(v: Viewer):
27-
# choose if we segment with/without tiling based on the image shape
28-
seg = segment_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
24+
# TODO expose more parameters
25+
@magicgui(call_button="Segment All Objects", method={"choices": ["default", "sam", "embeddings"]})
26+
def autosegment_widget(v: Viewer, method: str = "default"):
27+
if method in ("default", "sam"):
28+
print("Run automatic segmentation with SAM. This can take a few minutes ...")
29+
image = v.layers["raw"].data
30+
seg = segment_instances.segment_instances_sam(SAM, image)
31+
elif method == "embeddings":
32+
seg = segment_instances.segment_instances_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
33+
else:
34+
raise ValueError
2935
v.layers["auto_segmentation"].data = seg
3036
v.layers["auto_segmentation"].refresh()
3137

3238

3339
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
3440
# for access to the predictor and the image embeddings in the widgets
35-
global PREDICTOR, IMAGE_EMBEDDINGS
41+
global PREDICTOR, IMAGE_EMBEDDINGS, SAM
3642

37-
PREDICTOR = util.get_sam_model()
43+
PREDICTOR, SAM = util.get_sam_model(return_sam=True)
3844
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
3945
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
4046

0 commit comments

Comments
 (0)