Skip to content

Commit 6203432

Browse files
Merge pull request #22 from computational-cell-analytics/box-prompts
Implement box prompts for all annotators
2 parents c8209f2 + 691d26e commit 6203432

File tree

6 files changed

+359
-74
lines changed

6 files changed

+359
-74
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,27 @@
77
from .. import util
88
from .. import segment_instances
99
from ..visualization import project_embeddings_for_visualization
10-
from ..segment_from_prompts import segment_from_points
1110
from .util import (
12-
commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
11+
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
12+
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation, toggle_label, LABEL_COLOR_CYCLE
1313
)
1414

1515

1616
@magicgui(call_button="Segment Object [S]")
1717
def segment_wigdet(v: Viewer):
18+
# get the current box and point prompts
19+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"])
1820
points, labels = prompt_layer_to_points(v.layers["prompts"])
19-
seg = segment_from_points(PREDICTOR, points, labels)
20-
v.layers["current_object"].data = seg.squeeze()
21+
22+
shape = v.layers["current_object"].data.shape
23+
seg = prompt_segmentation(PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=True)
24+
25+
# no prompts were given or prompts were invalid, skip segmentation
26+
if seg is None:
27+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
28+
return
29+
30+
v.layers["current_object"].data = seg
2131
v.layers["current_object"].refresh()
2232

2333

@@ -85,6 +95,10 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
8595
)
8696
prompts.edge_color_mode = "cycle"
8797

98+
v.add_shapes(
99+
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts"
100+
)
101+
88102
#
89103
# add the widgets
90104
#
@@ -116,8 +130,7 @@ def _toggle_label(event=None):
116130

117131
@v.bind_key("Shift-C")
118132
def clear_prompts(v):
119-
prompts.data = []
120-
prompts.refresh()
133+
clear_all_prompts(v)
121134

122135
#
123136
# start the viewer

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from napari.utils import progress
77

88
from .. import util
9-
from ..segment_from_prompts import segment_from_mask, segment_from_points
9+
from ..segment_from_prompts import segment_from_mask
1010
from ..visualization import project_embeddings_for_visualization
1111
from .util import (
12-
commit_segmentation_widget, create_prompt_menu,
13-
prompt_layer_to_points, segment_slices_with_prompts,
14-
toggle_label, LABEL_COLOR_CYCLE
12+
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
13+
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation,
14+
segment_slices_with_prompts, toggle_label, LABEL_COLOR_CYCLE
1515
)
1616

1717

@@ -95,7 +95,9 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
9595

9696
else: # there is a range of more than 2 slices in between -> segment ranges
9797
# segment from bottom
98-
segment_range(z_start, z_mid, 1, np.greater_equal, verbose=verbose)
98+
segment_range(
99+
z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
100+
)
99101
# segment from top
100102
segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
101103
# if the difference between start and stop is even,
@@ -121,14 +123,26 @@ def segment_slice_wigdet(v: Viewer):
121123
position = v.cursor.position
122124
z = int(position[0])
123125

124-
this_prompts = prompt_layer_to_points(v.layers["prompts"], z)
125-
if this_prompts is None:
126+
point_prompts = prompt_layer_to_points(v.layers["prompts"], z)
127+
# this is a stop prompt, we do nothing
128+
if not point_prompts:
126129
return
127130

128-
points, labels = this_prompts
129-
seg = segment_from_points(PREDICTOR, points, labels, image_embeddings=IMAGE_EMBEDDINGS, i=z)
131+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"], z)
132+
points, labels = point_prompts
130133

131-
v.layers["current_object"].data[z] = seg.squeeze()
134+
shape = v.layers["current_object"].data.shape[1:]
135+
seg = prompt_segmentation(
136+
PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=False,
137+
image_embeddings=IMAGE_EMBEDDINGS, i=z
138+
)
139+
140+
# no prompts were given or prompts were invalid, skip segmentation
141+
if seg is None:
142+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
143+
return
144+
145+
v.layers["current_object"].data[z] = seg
132146
v.layers["current_object"].refresh()
133147

134148

@@ -147,7 +161,7 @@ def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, projection: str
147161
with progress(total=shape[0]) as progress_bar:
148162

149163
seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
150-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar,
164+
PREDICTOR, v.layers["prompts"], v.layers["box_prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar,
151165
)
152166

153167
# step 2: segment the rest of the volume based on smart prompting
@@ -205,6 +219,10 @@ def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_r
205219
)
206220
prompts.edge_color_mode = "cycle"
207221

222+
v.add_shapes(
223+
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts", ndim=3
224+
)
225+
208226
#
209227
# add the widgets
210228
#
@@ -241,8 +259,7 @@ def _toggle_label(event=None):
241259

242260
@v.bind_key("Shift-C")
243261
def clear_prompts(v):
244-
prompts.data = []
245-
prompts.refresh()
262+
clear_all_prompts(v)
246263

247264
#
248265
# start the viewer

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 90 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111
# from vigra.filters import eccentricityCenters
1212

1313
from .. import util
14-
from ..segment_from_prompts import segment_from_mask, segment_from_points
14+
from ..segment_from_prompts import segment_from_mask
1515
from .util import (
16-
create_prompt_menu, prompt_layer_to_points, prompt_layer_to_state,
16+
create_prompt_menu, clear_all_prompts,
17+
prompt_layer_to_boxes, prompt_layer_to_points,
18+
prompt_layer_to_state, prompt_segmentation,
1719
segment_slices_with_prompts, toggle_label, LABEL_COLOR_CYCLE
1820
)
1921
from ..visualization import project_embeddings_for_visualization
2022

21-
# Magenta and Cyan
22-
STATE_COLOR_CYCLE = ["#FF00FF", "#00FFFF"]
23+
# Cyan (track) and Magenta (division)
24+
STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
2325

2426

2527
#
@@ -56,7 +58,7 @@ def _shift_object(mask, motion_model):
5658

5759
# TODO division classifier
5860
def _track_from_prompts(
59-
prompt_layer, seg, predictor, slices, image_embeddings,
61+
point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
6062
stop_upper, threshold, projection,
6163
progress_bar=None, motion_smoothing=0.5,
6264
):
@@ -99,7 +101,9 @@ def _update_motion_model(seg, t, t0, motion_model):
99101
if t in slices:
100102
seg_prev = None
101103
seg_t = seg[t]
102-
track_state = prompt_layer_to_state(prompt_layer, t)
104+
# currently using the box layer doesn't work for keeping track of the track state
105+
# track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
106+
track_state = prompt_layer_to_state(point_prompts, t)
103107

104108
# otherwise project the mask (under the motion model) and segment the next slice from the mask
105109
else:
@@ -138,7 +142,7 @@ def _update_motion_model(seg, t, t0, motion_model):
138142
break
139143

140144
# stop if we are at the last slce
141-
if t == seg.shape[0] - 1:
145+
if t == seg.shape[0]:
142146
break
143147

144148
# stop if we have a division
@@ -180,9 +184,24 @@ def segment_frame_wigdet(v: Viewer):
180184
position = v.cursor.position
181185
t = int(position[0])
182186

183-
this_prompts = prompt_layer_to_points(v.layers["prompts"], t, track_id=CURRENT_TRACK_ID)
184-
points, labels = this_prompts
185-
seg = segment_from_points(PREDICTOR, points, labels, image_embeddings=IMAGE_EMBEDDINGS, i=t)
187+
point_prompts = prompt_layer_to_points(v.layers["prompts"], t, track_id=CURRENT_TRACK_ID)
188+
# this is a stop prompt, we do nothing
189+
if not point_prompts:
190+
return
191+
192+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"], t, track_id=CURRENT_TRACK_ID)
193+
points, labels = point_prompts
194+
195+
shape = v.layers["current_track"].data.shape[1:]
196+
seg = prompt_segmentation(
197+
PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=False,
198+
image_embeddings=IMAGE_EMBEDDINGS, i=t
199+
)
200+
201+
# no prompts were given or prompts were invalid, skip segmentation
202+
if seg is None:
203+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
204+
return
186205

187206
# clear the old segmentation for this track_id
188207
old_mask = v.layers["current_track"].data[t] == CURRENT_TRACK_ID
@@ -199,24 +218,23 @@ def track_objet_widget(
199218
):
200219
shape = v.layers["raw"].data.shape
201220

202-
# choose mask projection for square images and bounding box projection otherwise
203-
# (because mask projection does not work properly for non-square images yet)
204-
if projection == "default":
205-
projection_ = "mask" if shape[1] == shape[2] else "bounding_box"
206-
else:
207-
projection_ = projection
221+
# we use the bounding box projection method as default which generally seems to work better for larger changes
222+
# between frames (which is pretty tyipical for tracking compared to 3d segmentation)
223+
projection_ = "bounding_box" if projection == "default" else projection
208224

209225
with progress(total=shape[0]) as progress_bar:
210226
# step 1: segment all slices with prompts
211227
seg, slices, _, stop_upper = segment_slices_with_prompts(
212-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape,
228+
PREDICTOR, v.layers["prompts"], v.layers["box_prompts"], IMAGE_EMBEDDINGS, shape,
213229
progress_bar=progress_bar, track_id=CURRENT_TRACK_ID
214230
)
215231

216232
# step 2: track the object starting from the lowest annotated slice
217233
seg, has_division = _track_from_prompts(
218-
v.layers["prompts"], seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
219-
projection=projection_, progress_bar=progress_bar, motion_smoothing=motion_smoothing,
234+
v.layers["prompts"], v.layers["box_prompts"], seg,
235+
PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper,
236+
threshold=iou_threshold, projection=projection_,
237+
progress_bar=progress_bar, motion_smoothing=motion_smoothing,
220238
)
221239

222240
# if a division has occurred and it's the first time it occurred for this track
@@ -231,7 +249,7 @@ def track_objet_widget(
231249
v.layers["current_track"].refresh()
232250

233251

234-
def create_tracking_menu(points_layer, states, track_ids):
252+
def create_tracking_menu(points_layer, box_layer, states, track_ids):
235253
state_menu = ComboBox(label="track_state", choices=states)
236254
track_id_menu = ComboBox(label="track_id", choices=list(map(str, track_ids)))
237255
tracking_widget = Container(widgets=[state_menu, track_id_menu])
@@ -245,11 +263,25 @@ def update_track_id(event):
245263
global CURRENT_TRACK_ID
246264
new_id = str(points_layer.current_properties["track_id"][0])
247265
if new_id != track_id_menu.value:
248-
state_menu.value = new_id
266+
track_id_menu.value = new_id
267+
CURRENT_TRACK_ID = int(new_id)
268+
269+
# def update_state_boxes(event):
270+
# new_state = str(box_layer.current_properties["state"][0])
271+
# if new_state != state_menu.value:
272+
# state_menu.value = new_state
273+
274+
def update_track_id_boxes(event):
275+
global CURRENT_TRACK_ID
276+
new_id = str(box_layer.current_properties["track_id"][0])
277+
if new_id != track_id_menu.value:
278+
track_id_menu.value = new_id
249279
CURRENT_TRACK_ID = int(new_id)
250280

251281
points_layer.events.current_properties.connect(update_state)
252282
points_layer.events.current_properties.connect(update_track_id)
283+
# box_layer.events.current_properties.connect(update_state_boxes)
284+
box_layer.events.current_properties.connect(update_track_id_boxes)
253285

254286
def state_changed(new_state):
255287
current_properties = points_layer.current_properties
@@ -264,8 +296,23 @@ def track_id_changed(new_track_id):
264296
points_layer.current_properties = current_properties
265297
CURRENT_TRACK_ID = int(new_track_id)
266298

299+
# def state_changed_boxes(new_state):
300+
# current_properties = box_layer.current_properties
301+
# current_properties["state"] = np.array([new_state])
302+
# box_layer.current_properties = current_properties
303+
# box_layer.refresh_colors()
304+
305+
def track_id_changed_boxes(new_track_id):
306+
global CURRENT_TRACK_ID
307+
current_properties = box_layer.current_properties
308+
current_properties["track_id"] = np.array([new_track_id])
309+
box_layer.current_properties = current_properties
310+
CURRENT_TRACK_ID = int(new_track_id)
311+
267312
state_menu.changed.connect(state_changed)
268313
track_id_menu.changed.connect(track_id_changed)
314+
# state_menu.changed.connect(state_changed_boxes)
315+
track_id_menu.changed.connect(track_id_changed_boxes)
269316

270317
state_menu.set_choice("track")
271318
return tracking_widget
@@ -295,8 +342,7 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
295342
v.layers[layer].data = np.zeros(shape, dtype="uint32")
296343
v.layers[layer].refresh()
297344

298-
v.layers["prompts"].data = []
299-
v.layers["prompts"].refresh()
345+
clear_all_prompts(v)
300346

301347

302348
def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None, model_type="vit_h"):
@@ -333,7 +379,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
333379
# add the widgets
334380
#
335381
labels = ["positive", "negative"]
336-
state_labels = ["division", "track"]
382+
state_labels = ["track", "division"]
337383
prompts = v.add_points(
338384
data=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], # FIXME workaround
339385
name="prompts",
@@ -354,6 +400,24 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
354400
prompts.edge_color_mode = "cycle"
355401
prompts.face_color_mode = "cycle"
356402

403+
# using the box layer to set divisions currently doesn't work
404+
# (and setting new track ids also doesn't work, but keeping track of them in the properties is working)
405+
box_prompts = v.add_shapes(
406+
data=[
407+
np.array([[0, 0, 0], [0, 0, 10], [0, 10, 0], [0, 10, 10]]),
408+
np.array([[0, 0, 0], [0, 0, 11], [0, 11, 0], [0, 11, 11]]),
409+
], # FIXME workaround
410+
shape_type="rectangle", # FIXME workaround
411+
edge_width=4, ndim=3,
412+
face_color="transparent",
413+
name="box_prompts",
414+
edge_color="green",
415+
properties={"track_id": ["1", "1"]},
416+
# properties={"track_id": ["1", "1"], "state": state_labels},
417+
# edge_color_cycle=STATE_COLOR_CYCLE,
418+
)
419+
# box_prompts.edge_color_mode = "cycle"
420+
357421
#
358422
# add the widgets
359423
#
@@ -363,7 +427,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
363427
prompt_widget = create_prompt_menu(prompts, labels)
364428
v.window.add_dock_widget(prompt_widget)
365429

366-
TRACKING_WIDGET = create_tracking_menu(prompts, state_labels, list(LINEAGE.keys()))
430+
TRACKING_WIDGET = create_tracking_menu(prompts, box_prompts, state_labels, list(LINEAGE.keys()))
367431
v.window.add_dock_widget(TRACKING_WIDGET)
368432

369433
v.window.add_dock_widget(segment_frame_wigdet)
@@ -392,8 +456,7 @@ def _commit(v):
392456

393457
@v.bind_key("Shift-C")
394458
def clear_prompts(v):
395-
prompts.data = []
396-
prompts.refresh()
459+
clear_all_prompts(v)
397460

398461
#
399462
# start the viewer

0 commit comments

Comments
 (0)