Skip to content

Commit 1e714d2

Browse files
Refactor segmentation from prompts in annotators and support box prompts in 3d annotator
1 parent c91993d commit 1e714d2

File tree

3 files changed

+170
-86
lines changed

3 files changed

+170
-86
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 13 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,67 +7,24 @@
77
from .. import util
88
from .. import segment_instances
99
from ..visualization import project_embeddings_for_visualization
10-
from ..segment_from_prompts import segment_from_box, segment_from_box_and_points, segment_from_points
1110
from .util import (
12-
commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
11+
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
12+
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation, toggle_label, LABEL_COLOR_CYCLE
1313
)
1414

1515

1616
@magicgui(call_button="Segment Object [S]")
1717
def segment_wigdet(v: Viewer):
18-
# get the current point prompts
18+
# get the current box and point prompts
19+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"])
1920
points, labels = prompt_layer_to_points(v.layers["prompts"])
20-
assert len(points) == len(labels)
21-
have_points = len(points) > 0
22-
23-
# get the current box prompts
24-
box_layer = v.layers["box_prompts"]
25-
have_boxes = box_layer.nshapes > 0
26-
27-
# segment only with points
28-
if have_points and not have_boxes:
29-
seg = segment_from_points(PREDICTOR, points, labels).squeeze()
30-
31-
# segment only with boxes
32-
elif not have_points and have_boxes:
33-
shape = v.layers["current_object"].data.shape
34-
seg = np.zeros(shape, dtype="uint32")
35-
36-
seg_id = 1
37-
for prompt_id in range(box_layer.nshapes):
38-
shape_type = box_layer.shape_type[prompt_id]
39-
40-
# for now we only support segmentation from rectangles.
41-
# supporting other shapes would be possible by casting the shape to a mask
42-
# and then segmenting from mask and bounding box.
43-
# but for this we need to fix issue with resizing the mask for non-square shapes.
44-
if shape_type != "rectangle":
45-
print(f"You have provided a {shape_type} shape.")
46-
print("We currently only support rectangle shapes for prompts and this prompt will be skipped.")
47-
continue
48-
49-
box = box_layer.data[prompt_id]
50-
prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()])
51-
mask = segment_from_box(PREDICTOR, prompt_box).squeeze()
52-
seg[mask] = seg_id
53-
seg_id += 1
54-
55-
# segment with points and box (currently only one box supported)
56-
elif have_points and have_boxes:
57-
if box_layer.nshapes > 1:
58-
print("You have provided point prompts and more than one box prompt.")
59-
print("This setting is currently not supported.")
60-
print("When providing both points and prompts you can only segment one object at a time.")
61-
return
62-
63-
box = box_layer.data[0]
64-
prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()])
65-
seg = segment_from_box_and_points(PREDICTOR, prompt_box, points, labels).squeeze()
66-
67-
# no prompts were given, skip segmentation
68-
else:
69-
print("You haven't given any prompts.")
70-
print("Please provide point and/or box prompts.")
21+
22+
shape = v.layers["current_object"].data.shape
23+
seg = prompt_segmentation(PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=True)
24+
25+
# no prompts were given or prompts were invalid, skip segmentation
26+
if seg is None:
27+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
7128
return
7229

7330
v.layers["current_object"].data = seg
@@ -138,7 +95,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
13895
)
13996
prompts.edge_color_mode = "cycle"
14097

141-
box_prompts = v.add_shapes(
98+
v.add_shapes(
14299
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts"
143100
)
144101

@@ -173,10 +130,7 @@ def _toggle_label(event=None):
173130

174131
@v.bind_key("Shift-C")
175132
def clear_prompts(v):
176-
prompts.data = []
177-
prompts.refresh()
178-
box_prompts.data = []
179-
box_prompts.refresh()
133+
clear_all_prompts(v)
180134

181135
#
182136
# start the viewer

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from napari.utils import progress
77

88
from .. import util
9-
from ..segment_from_prompts import segment_from_mask, segment_from_points
9+
from ..segment_from_prompts import segment_from_mask
1010
from ..visualization import project_embeddings_for_visualization
1111
from .util import (
12-
commit_segmentation_widget, create_prompt_menu,
13-
prompt_layer_to_points, segment_slices_with_prompts,
14-
toggle_label, LABEL_COLOR_CYCLE
12+
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
13+
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation,
14+
segment_slices_with_prompts, toggle_label, LABEL_COLOR_CYCLE
1515
)
1616

1717

@@ -121,14 +121,26 @@ def segment_slice_wigdet(v: Viewer):
121121
position = v.cursor.position
122122
z = int(position[0])
123123

124-
this_prompts = prompt_layer_to_points(v.layers["prompts"], z)
125-
if this_prompts is None:
124+
point_prompts = prompt_layer_to_points(v.layers["prompts"], z)
125+
# this is a stop prompt, we do nothing
126+
if not point_prompts:
126127
return
127128

128-
points, labels = this_prompts
129-
seg = segment_from_points(PREDICTOR, points, labels, image_embeddings=IMAGE_EMBEDDINGS, i=z)
129+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"], z)
130+
points, labels = point_prompts
130131

131-
v.layers["current_object"].data[z] = seg.squeeze()
132+
shape = v.layers["current_object"].data.shape[1:]
133+
seg = prompt_segmentation(
134+
PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=False,
135+
image_embeddings=IMAGE_EMBEDDINGS, i=z
136+
)
137+
138+
# no prompts were given or prompts were invalid, skip segmentation
139+
if seg is None:
140+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
141+
return
142+
143+
v.layers["current_object"].data[z] = seg
132144
v.layers["current_object"].refresh()
133145

134146

@@ -147,7 +159,7 @@ def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, projection: str
147159
with progress(total=shape[0]) as progress_bar:
148160

149161
seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
150-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar,
162+
PREDICTOR, v.layers["prompts"], v.layers["box_prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar,
151163
)
152164

153165
# step 2: segment the rest of the volume based on smart prompting
@@ -205,6 +217,10 @@ def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_r
205217
)
206218
prompts.edge_color_mode = "cycle"
207219

220+
v.add_shapes(
221+
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts", ndim=3
222+
)
223+
208224
#
209225
# add the widgets
210226
#
@@ -241,8 +257,7 @@ def _toggle_label(event=None):
241257

242258
@v.bind_key("Shift-C")
243259
def clear_prompts(v):
244-
prompts.data = []
245-
prompts.refresh()
260+
clear_all_prompts(v)
246261

247262
#
248263
# start the viewer

micro_sam/sam_annotator/util.py

Lines changed: 130 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
from magicgui.widgets import ComboBox, Container
55
from napari import Viewer
66

7-
from ..segment_from_prompts import segment_from_points
7+
from ..segment_from_prompts import segment_from_box, segment_from_box_and_points, segment_from_points
88

99
# Green and Red
1010
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]
1111

1212

13+
def clear_all_prompts(v):
14+
v.layers["prompts"].data = []
15+
v.layers["prompts"].refresh()
16+
if "box_prompts" in v.layers:
17+
v.layers["box_prompts"].data = []
18+
v.layers["box_prompts"].refresh()
19+
20+
1321
@magicgui(call_button="Commit [C]", layer={"choices": ["current_object", "auto_segmentation"]})
1422
def commit_segmentation_widget(v: Viewer, layer: str = "current_object"):
1523
seg = v.layers[layer].data
@@ -25,11 +33,7 @@ def commit_segmentation_widget(v: Viewer, layer: str = "current_object"):
2533
v.layers[layer].refresh()
2634

2735
if layer == "current_object":
28-
v.layers["prompts"].data = []
29-
v.layers["prompts"].refresh()
30-
if "box_prompts" in v.layers:
31-
v.layers["box_prompts"].data = []
32-
v.layers["box_prompts"].refresh()
36+
clear_all_prompts(v)
3337

3438

3539
def create_prompt_menu(points_layer, labels, menu_name="prompt", label_name="label"):
@@ -59,7 +63,8 @@ def prompt_layer_to_points(prompt_layer, i=None, track_id=None):
5963
6064
Arguments:
6165
prompt_layer: the point layer
62-
i [int] - index for the data (required for 3d data)
66+
i [int] - index for the data (required for 3d or timeseries data)
67+
track_id [int] - id of the current track (required for tracking data)
6368
"""
6469

6570
points = prompt_layer.data
@@ -92,6 +97,50 @@ def prompt_layer_to_points(prompt_layer, i=None, track_id=None):
9297
return this_points, this_labels
9398

9499

100+
def prompt_layer_to_boxes(prompt_layer, i=None, track_id=None):
101+
"""Extract box prompts for SAM from shape layer.
102+
103+
Arguments:
104+
prompt_layer: the point layer
105+
i [int] - index for the data (required for 3d or timeseries data)
106+
track_id [int] - id of the current track (required for tracking data)
107+
"""
108+
shape_data = prompt_layer.data
109+
shape_types = prompt_layer.shape_type
110+
assert len(shape_data) == len(shape_types)
111+
112+
if i is None:
113+
# select all boxes that are rectangles
114+
boxes = [data for data, stype in zip(shape_data, shape_types) if stype == "rectangle"]
115+
else:
116+
# we are currently only supporting rectangle shapes.
117+
# other shapes could be supported by providing them as rough mask
118+
# (and also providing the corresponding bounding box)
119+
# but for this we need to figure out the mask prompts for non-square shapes
120+
non_rectangle = [stype != "rectangle" for stype in shape_types]
121+
if any(non_rectangle):
122+
print(f"You have provided {sum(non_rectangle)} shapes that are not rectangles.")
123+
print("We currently do not support these as prompts and they will be ignored.")
124+
boxes = [
125+
data[:, 1:] for data, stype in zip(shape_data, shape_types)
126+
if (stype == "rectangle" and (data[:, 0] == i).all())
127+
]
128+
129+
# TODO support for track_id
130+
# if track_id is not None:
131+
# assert i is not None
132+
# track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))[mask]
133+
# track_id_mask = track_ids == track_id
134+
# this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
135+
# assert len(this_points) == len(this_labels)
136+
137+
# map to correct box format
138+
boxes = [
139+
np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
140+
]
141+
return boxes
142+
143+
95144
def prompt_layer_to_state(prompt_layer, i):
96145
"""Get the state of the track from the prompt layer.
97146
Only relevant for annotator_tracking.
@@ -116,27 +165,36 @@ def prompt_layer_to_state(prompt_layer, i):
116165
return "track"
117166

118167

119-
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape, progress_bar=None, track_id=None):
168+
def segment_slices_with_prompts(
169+
predictor, point_prompts, box_prompts, image_embeddings, shape, progress_bar=None, track_id=None
170+
):
171+
"""
172+
"""
173+
assert len(shape) == 3
174+
image_shape = shape[1:]
120175
seg = np.zeros(shape, dtype="uint32")
121176

122-
z_values = prompt_layer.data[:, 0]
177+
z_values = point_prompts.data[:, 0]
178+
z_values_boxes = np.concatenate([box[:, 0] for box in box_prompts.data])
179+
180+
# TODO add track id properties to boxes as well, filter z_values_boxes accordingly
123181
if track_id is not None:
124-
track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))
182+
track_ids = np.array(list(map(int, point_prompts.properties["track_id"])))
125183
assert len(track_ids) == len(z_values)
126184
z_values = z_values[track_ids == track_id]
127185

128-
slices = np.unique(z_values).astype("int")
186+
slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
129187
stop_lower, stop_upper = False, False
130188

131189
def _update_progress():
132190
if progress_bar is not None:
133191
progress_bar.update(1)
134192

135193
for i in slices:
136-
prompts_i = prompt_layer_to_points(prompt_layer, i, track_id)
194+
points_i = prompt_layer_to_points(point_prompts, i, track_id)
137195

138196
# do we end the segmentation at the outer slices?
139-
if prompts_i is None:
197+
if points_i is None:
140198

141199
if i == slices[0]:
142200
stop_lower = True
@@ -149,14 +207,71 @@ def _update_progress():
149207
_update_progress()
150208
continue
151209

152-
points, labels = prompts_i
153-
seg_i = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings, i=i)
210+
boxes = prompt_layer_to_boxes(box_prompts, i, track_id)
211+
points, labels = points_i
212+
213+
seg_i = prompt_segmentation(
214+
predictor, points, labels, boxes, image_shape, multiple_box_prompts=False,
215+
image_embeddings=image_embeddings, i=i
216+
)
217+
if seg_i is None:
218+
print(f"The prompts at slice or frame {i} are invalid and the segmentation was skipped.")
219+
print("This will lead to a wrong segmentation across slices or frames.")
220+
print(f"Please correct the prompts in {i} and rerun the segmentation.")
221+
continue
222+
154223
seg[i] = seg_i
155224
_update_progress()
156225

157226
return seg, slices, stop_lower, stop_upper
158227

159228

229+
def prompt_segmentation(
230+
predictor, points, labels, boxes, shape, multiple_box_prompts, image_embeddings=None, i=None
231+
):
232+
"""
233+
"""
234+
assert len(points) == len(labels)
235+
have_points = len(points) > 0
236+
have_boxes = len(boxes) > 0
237+
238+
# no prompts were given, return None
239+
if not have_points and not have_boxes:
240+
return
241+
242+
# box and ppint prompts were given
243+
elif have_points and have_boxes:
244+
if len(boxes) > 1:
245+
print("You have provided point prompts and more than one box prompt.")
246+
print("This setting is currently not supported.")
247+
print("When providing both points and prompts you can only segment one object at a time.")
248+
return
249+
seg = segment_from_box_and_points(
250+
predictor, boxes[0], points, labels, image_embeddings=image_embeddings, i=i
251+
).squeeze()
252+
253+
# only point prompts were given
254+
elif have_points and not have_boxes:
255+
seg = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings, i=i).squeeze()
256+
257+
# only box prompts were given
258+
elif not have_points and have_boxes:
259+
seg = np.zeros(shape, dtype="uint32")
260+
261+
if len(boxes) > 1 and not multiple_box_prompts:
262+
print("You have provided more than one box annotation. This is not yet supported in the 3d annotator.")
263+
print("You can only segment one object at a time in 3d.")
264+
return
265+
266+
seg_id = 1
267+
for box in boxes:
268+
mask = segment_from_box(predictor, box, image_embeddings=image_embeddings, i=i).squeeze()
269+
seg[mask] = seg_id
270+
seg_id += 1
271+
272+
return seg
273+
274+
160275
def toggle_label(prompts):
161276
# get the currently selected label
162277
current_properties = prompts.current_properties

0 commit comments

Comments
 (0)