Skip to content

Commit ff24b34

Browse files
Extend annotator_2d fucntionality
1 parent e2ab3f7 commit ff24b34

File tree

3 files changed

+38
-36
lines changed

3 files changed

+38
-36
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# TODO make napari imports optional so we can use micro_sam as pure library
21
import napari
32
import numpy as np
43

@@ -8,7 +7,7 @@
87
from .. import util
98
from ..visualization import compute_pca
109
from ..segment_from_prompts import segment_from_points
11-
from .util import create_prompt_menu, prompt_layer_to_points
10+
from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points
1211

1312
COLOR_CYCLE = ["#00FF00", "#FF0000"]
1413

@@ -21,10 +20,9 @@ def segment_wigdet(v: Viewer):
2120
v.layers["current_object"].refresh()
2221

2322

24-
def annotator_2d(raw, embedding_path=None, show_embeddings=False):
23+
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
2524
# for access to the predictor and the image embeddings in the widgets
26-
global PREDICTOR, NEXT_ID
27-
NEXT_ID = 1
25+
global PREDICTOR
2826

2927
PREDICTOR = util.get_sam_model()
3028
image_embeddings = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
@@ -37,7 +35,10 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False):
3735
v = Viewer()
3836

3937
v.add_image(raw)
40-
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
38+
if segmentation_result is None:
39+
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
40+
else:
41+
v.add_labels(segmentation_result, name="committed_objects")
4142
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="current_object")
4243

4344
# show the PCA of the image embeddings
@@ -71,10 +72,8 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False):
7172
v.window.add_dock_widget(prompt_widget)
7273

7374
v.window.add_dock_widget(segment_wigdet)
75+
v.window.add_dock_widget(commit_segmentation_widget)
7476

75-
#
76-
# start the viewer
77-
#
7877
#
7978
# key bindings
8079
#
@@ -83,9 +82,9 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False):
8382
def _segmet(v):
8483
segment_wigdet(v)
8584

86-
# @v.bind_key("c")
87-
# def _commit(v):
88-
# commit_widget(v)
85+
@v.bind_key("c")
86+
def _commit(v):
87+
commit_segmentation_widget(v)
8988

9089
@v.bind_key("t")
9190
def toggle_label(event=None):
@@ -103,6 +102,10 @@ def clear_prompts(v):
103102
prompts.data = []
104103
prompts.refresh()
105104

105+
#
106+
# start the viewer
107+
#
108+
106109
# clear the initial points needed for workaround
107110
clear_prompts(v)
108111
napari.run()

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .. import util
88
from ..segment_from_prompts import segment_from_mask, segment_from_points
99
from ..visualization import compute_pca
10-
from .util import create_prompt_menu, prompt_layer_to_points
10+
from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points
1111

1212
COLOR_CYCLE = ["#00FF00", "#FF0000"]
1313

@@ -160,27 +160,9 @@ def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "
160160
v.layers["current_object"].refresh()
161161

162162

163-
@magicgui(call_button="Commit [C]")
164-
def commit_widget(v: Viewer):
165-
global NEXT_ID
166-
seg = v.layers["current_object"].data
167-
168-
v.layers["committed_objects"].data[seg == 1] = NEXT_ID
169-
v.layers["committed_objects"].refresh()
170-
171-
shape = v.layers["raw"].data.shape
172-
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
173-
v.layers["current_object"].refresh()
174-
175-
v.layers["prompts"].data = []
176-
v.layers["prompts"].refresh()
177-
NEXT_ID += 1
178-
179-
180-
# TODO enable passing also an initial segmentation
181163
def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
182164
# for access to the predictor and the image embeddings in the widgets
183-
global PREDICTOR, IMAGE_EMBEDDINGS, NEXT_ID
165+
global PREDICTOR, IMAGE_EMBEDDINGS
184166
PREDICTOR = util.get_sam_model()
185167
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
186168

@@ -193,11 +175,9 @@ def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_r
193175
v.add_image(raw)
194176
if segmentation_result is None:
195177
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
196-
NEXT_ID = 1
197178
else:
198179
assert segmentation_result.shape == raw.shape
199180
v.add_labels(data=segmentation_result, name="committed_objects")
200-
NEXT_ID = int(segmentation_result.max()) + 1
201181
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="current_object")
202182

203183
# show the PCA of the image embeddings
@@ -234,7 +214,7 @@ def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_r
234214
# v.bind_key("s", segment_slice_wigdet) FIXME this causes an issue with all shortcuts
235215

236216
v.window.add_dock_widget(segment_volume_widget)
237-
v.window.add_dock_widget(commit_widget)
217+
v.window.add_dock_widget(commit_segmentation_widget)
238218

239219
#
240220
# key bindings
@@ -250,7 +230,7 @@ def _seg_volume(v):
250230

251231
@v.bind_key("c")
252232
def _commit(v):
253-
commit_widget(v)
233+
commit_segmentation_widget(v)
254234

255235
@v.bind_key("t")
256236
def toggle_label(event=None):

micro_sam/sam_annotator/util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
11
import numpy as np
2+
3+
from magicgui import magicgui
24
from magicgui.widgets import ComboBox, Container
5+
from napari import Viewer
6+
7+
8+
@magicgui(call_button="Commit [C]")
9+
def commit_segmentation_widget(v: Viewer):
10+
seg = v.layers["current_object"].data
11+
12+
next_id = int(v.layers["committed_objects"].data.max() + 1)
13+
v.layers["committed_objects"].data[seg == 1] = next_id
14+
v.layers["committed_objects"].refresh()
15+
16+
shape = v.layers["raw"].data.shape
17+
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
18+
v.layers["current_object"].refresh()
19+
20+
v.layers["prompts"].data = []
21+
v.layers["prompts"].refresh()
322

423

524
def create_prompt_menu(points_layer, labels):

0 commit comments

Comments
 (0)