Skip to content

Commit aa414a8

Browse files
Enable tracking annotator with box prompts
1 parent 4827d75 commit aa414a8

File tree

2 files changed

+142
-38
lines changed

2 files changed

+142
-38
lines changed

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111
# from vigra.filters import eccentricityCenters
1212

1313
from .. import util
14-
from ..segment_from_prompts import segment_from_mask, segment_from_points
14+
from ..segment_from_prompts import segment_from_mask
1515
from .util import (
16-
create_prompt_menu, prompt_layer_to_points, prompt_layer_to_state,
16+
create_prompt_menu, clear_all_prompts,
17+
prompt_layer_to_boxes, prompt_layer_to_points,
18+
prompt_layer_to_state, prompt_segmentation,
1719
segment_slices_with_prompts, toggle_label, LABEL_COLOR_CYCLE
1820
)
1921
from ..visualization import project_embeddings_for_visualization
2022

21-
# Magenta and Cyan
22-
STATE_COLOR_CYCLE = ["#FF00FF", "#00FFFF"]
23+
# Cyan (track) and Magenta (division)
24+
STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
2325

2426

2527
#
@@ -56,7 +58,7 @@ def _shift_object(mask, motion_model):
5658

5759
# TODO division classifier
5860
def _track_from_prompts(
59-
prompt_layer, seg, predictor, slices, image_embeddings,
61+
point_prompts, box_prompts, seg, predictor, slices, image_embeddings,
6062
stop_upper, threshold, projection,
6163
progress_bar=None, motion_smoothing=0.5,
6264
):
@@ -99,7 +101,9 @@ def _update_motion_model(seg, t, t0, motion_model):
99101
if t in slices:
100102
seg_prev = None
101103
seg_t = seg[t]
102-
track_state = prompt_layer_to_state(prompt_layer, t)
104+
# currently using the box layer doesn't work for keeping track of the track state
105+
# track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
106+
track_state = prompt_layer_to_state(point_prompts, t)
103107

104108
# otherwise project the mask (under the motion model) and segment the next slice from the mask
105109
else:
@@ -180,9 +184,24 @@ def segment_frame_wigdet(v: Viewer):
180184
position = v.cursor.position
181185
t = int(position[0])
182186

183-
this_prompts = prompt_layer_to_points(v.layers["prompts"], t, track_id=CURRENT_TRACK_ID)
184-
points, labels = this_prompts
185-
seg = segment_from_points(PREDICTOR, points, labels, image_embeddings=IMAGE_EMBEDDINGS, i=t)
187+
point_prompts = prompt_layer_to_points(v.layers["prompts"], t, track_id=CURRENT_TRACK_ID)
188+
# this is a stop prompt, we do nothing
189+
if not point_prompts:
190+
return
191+
192+
boxes = prompt_layer_to_boxes(v.layers["box_prompts"], t, track_id=CURRENT_TRACK_ID)
193+
points, labels = point_prompts
194+
195+
shape = v.layers["current_track"].data.shape[1:]
196+
seg = prompt_segmentation(
197+
PREDICTOR, points, labels, boxes, shape, multiple_box_prompts=False,
198+
image_embeddings=IMAGE_EMBEDDINGS, i=t
199+
)
200+
201+
# no prompts were given or prompts were invalid, skip segmentation
202+
if seg is None:
203+
print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
204+
return
186205

187206
# clear the old segmentation for this track_id
188207
old_mask = v.layers["current_track"].data[t] == CURRENT_TRACK_ID
@@ -209,14 +228,16 @@ def track_objet_widget(
209228
with progress(total=shape[0]) as progress_bar:
210229
# step 1: segment all slices with prompts
211230
seg, slices, _, stop_upper = segment_slices_with_prompts(
212-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape,
231+
PREDICTOR, v.layers["prompts"], v.layers["box_prompts"], IMAGE_EMBEDDINGS, shape,
213232
progress_bar=progress_bar, track_id=CURRENT_TRACK_ID
214233
)
215234

216235
# step 2: track the object starting from the lowest annotated slice
217236
seg, has_division = _track_from_prompts(
218-
v.layers["prompts"], seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
219-
projection=projection_, progress_bar=progress_bar, motion_smoothing=motion_smoothing,
237+
v.layers["prompts"], v.layers["box_prompts"], seg,
238+
PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper,
239+
threshold=iou_threshold, projection=projection_,
240+
progress_bar=progress_bar, motion_smoothing=motion_smoothing,
220241
)
221242

222243
# if a division has occurred and it's the first time it occurred for this track
@@ -231,7 +252,7 @@ def track_objet_widget(
231252
v.layers["current_track"].refresh()
232253

233254

234-
def create_tracking_menu(points_layer, states, track_ids):
255+
def create_tracking_menu(points_layer, box_layer, states, track_ids):
235256
state_menu = ComboBox(label="track_state", choices=states)
236257
track_id_menu = ComboBox(label="track_id", choices=list(map(str, track_ids)))
237258
tracking_widget = Container(widgets=[state_menu, track_id_menu])
@@ -245,11 +266,25 @@ def update_track_id(event):
245266
global CURRENT_TRACK_ID
246267
new_id = str(points_layer.current_properties["track_id"][0])
247268
if new_id != track_id_menu.value:
248-
state_menu.value = new_id
269+
track_id_menu.value = new_id
270+
CURRENT_TRACK_ID = int(new_id)
271+
272+
# def update_state_boxes(event):
273+
# new_state = str(box_layer.current_properties["state"][0])
274+
# if new_state != state_menu.value:
275+
# state_menu.value = new_state
276+
277+
def update_track_id_boxes(event):
278+
global CURRENT_TRACK_ID
279+
new_id = str(box_layer.current_properties["track_id"][0])
280+
if new_id != track_id_menu.value:
281+
track_id_menu.value = new_id
249282
CURRENT_TRACK_ID = int(new_id)
250283

251284
points_layer.events.current_properties.connect(update_state)
252285
points_layer.events.current_properties.connect(update_track_id)
286+
# box_layer.events.current_properties.connect(update_state_boxes)
287+
box_layer.events.current_properties.connect(update_track_id_boxes)
253288

254289
def state_changed(new_state):
255290
current_properties = points_layer.current_properties
@@ -264,8 +299,23 @@ def track_id_changed(new_track_id):
264299
points_layer.current_properties = current_properties
265300
CURRENT_TRACK_ID = int(new_track_id)
266301

302+
# def state_changed_boxes(new_state):
303+
# current_properties = box_layer.current_properties
304+
# current_properties["state"] = np.array([new_state])
305+
# box_layer.current_properties = current_properties
306+
# box_layer.refresh_colors()
307+
308+
def track_id_changed_boxes(new_track_id):
309+
global CURRENT_TRACK_ID
310+
current_properties = box_layer.current_properties
311+
current_properties["track_id"] = np.array([new_track_id])
312+
box_layer.current_properties = current_properties
313+
CURRENT_TRACK_ID = int(new_track_id)
314+
267315
state_menu.changed.connect(state_changed)
268316
track_id_menu.changed.connect(track_id_changed)
317+
# state_menu.changed.connect(state_changed_boxes)
318+
track_id_menu.changed.connect(track_id_changed_boxes)
269319

270320
state_menu.set_choice("track")
271321
return tracking_widget
@@ -295,8 +345,7 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
295345
v.layers[layer].data = np.zeros(shape, dtype="uint32")
296346
v.layers[layer].refresh()
297347

298-
v.layers["prompts"].data = []
299-
v.layers["prompts"].refresh()
348+
clear_all_prompts(v)
300349

301350

302351
def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None, model_type="vit_h"):
@@ -333,7 +382,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
333382
# add the widgets
334383
#
335384
labels = ["positive", "negative"]
336-
state_labels = ["division", "track"]
385+
state_labels = ["track", "division"]
337386
prompts = v.add_points(
338387
data=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], # FIXME workaround
339388
name="prompts",
@@ -354,6 +403,24 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
354403
prompts.edge_color_mode = "cycle"
355404
prompts.face_color_mode = "cycle"
356405

406+
# using the box layer to set divisions currently doesn't work
407+
# (and setting new track ids also doesn't work, but keeping track of them in the properties is working)
408+
box_prompts = v.add_shapes(
409+
data=[
410+
np.array([[0, 0, 0], [0, 0, 10], [0, 10, 0], [0, 10, 10]]),
411+
np.array([[0, 0, 0], [0, 0, 11], [0, 11, 0], [0, 11, 11]]),
412+
], # FIXME workaround
413+
shape_type="rectangle", # FIXME workaround
414+
edge_width=4, ndim=3,
415+
face_color="transparent",
416+
name="box_prompts",
417+
edge_color="green",
418+
properties={"track_id": ["1", "1"]},
419+
# properties={"track_id": ["1", "1"], "state": state_labels},
420+
# edge_color_cycle=STATE_COLOR_CYCLE,
421+
)
422+
# box_prompts.edge_color_mode = "cycle"
423+
357424
#
358425
# add the widgets
359426
#
@@ -363,7 +430,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
363430
prompt_widget = create_prompt_menu(prompts, labels)
364431
v.window.add_dock_widget(prompt_widget)
365432

366-
TRACKING_WIDGET = create_tracking_menu(prompts, state_labels, list(LINEAGE.keys()))
433+
TRACKING_WIDGET = create_tracking_menu(prompts, box_prompts, state_labels, list(LINEAGE.keys()))
367434
v.window.add_dock_widget(TRACKING_WIDGET)
368435

369436
v.window.add_dock_widget(segment_frame_wigdet)
@@ -392,8 +459,7 @@ def _commit(v):
392459

393460
@v.bind_key("Shift-C")
394461
def clear_prompts(v):
395-
prompts.data = []
396-
prompts.refresh()
462+
clear_all_prompts(v)
397463

398464
#
399465
# start the viewer

micro_sam/sam_annotator/util.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,19 @@ def prompt_layer_to_boxes(prompt_layer, i=None, track_id=None):
121121
if any(non_rectangle):
122122
print(f"You have provided {sum(non_rectangle)} shapes that are not rectangles.")
123123
print("We currently do not support these as prompts and they will be ignored.")
124-
boxes = [
125-
data[:, 1:] for data, stype in zip(shape_data, shape_types)
126-
if (stype == "rectangle" and (data[:, 0] == i).all())
127-
]
128-
129-
# TODO support for track_id
130-
# if track_id is not None:
131-
# assert i is not None
132-
# track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))[mask]
133-
# track_id_mask = track_ids == track_id
134-
# this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
135-
# assert len(this_points) == len(this_labels)
124+
125+
if track_id is None:
126+
boxes = [
127+
data[:, 1:] for data, stype in zip(shape_data, shape_types)
128+
if (stype == "rectangle" and (data[:, 0] == i).all())
129+
]
130+
else:
131+
track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))
132+
assert len(track_ids) == len(shape_data)
133+
boxes = [
134+
data[:, 1:] for data, stype, this_track_id in zip(shape_data, shape_types, track_ids)
135+
if (stype == "rectangle" and (data[:, 0] == i).all() and this_track_id == track_id)
136+
]
136137

137138
# map to correct box format
138139
boxes = [
@@ -147,7 +148,7 @@ def prompt_layer_to_state(prompt_layer, i):
147148
148149
Arguments:
149150
prompt_layer: the point layer
150-
i [int] - index for the data (required for 3d data)
151+
i [int] - frame of the data
151152
"""
152153
state = prompt_layer.properties["state"]
153154

@@ -165,6 +166,39 @@ def prompt_layer_to_state(prompt_layer, i):
165166
return "track"
166167

167168

169+
def prompt_layers_to_state(point_layer, box_layer, i):
170+
"""Get the state of the track from the point and box prompt layer.
171+
Only relevant for annotator_tracking.
172+
173+
Arguments:
174+
point_layer: the point layer
175+
box_layer: the box layer
176+
i [int] - frame of the data
177+
"""
178+
state = point_layer.properties["state"]
179+
180+
points = point_layer.data
181+
assert points.shape[1] == 3, f"{points.shape}"
182+
mask = points[:, 0] == i
183+
if mask.sum() > 0:
184+
this_state = state[mask].tolist()
185+
else:
186+
this_state = []
187+
188+
box_states = box_layer.properties["state"]
189+
this_box_states = [
190+
state for box, state in zip(box_layer.data, box_states)
191+
if (box[:, 0] == i).all()
192+
]
193+
this_state.extend(this_box_states)
194+
195+
# we set the state to 'division' if at least one point in this frame has a division label
196+
if any(st == "division" for st in this_state):
197+
return "division"
198+
else:
199+
return "track"
200+
201+
168202
def segment_slices_with_prompts(
169203
predictor, point_prompts, box_prompts, image_embeddings, shape, progress_bar=None, track_id=None
170204
):
@@ -175,13 +209,17 @@ def segment_slices_with_prompts(
175209
seg = np.zeros(shape, dtype="uint32")
176210

177211
z_values = point_prompts.data[:, 0]
178-
z_values_boxes = np.concatenate([box[:, 0] for box in box_prompts.data])
212+
z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
213+
np.zeros(0, dtype="int")
179214

180-
# TODO add track id properties to boxes as well, filter z_values_boxes accordingly
181215
if track_id is not None:
182-
track_ids = np.array(list(map(int, point_prompts.properties["track_id"])))
183-
assert len(track_ids) == len(z_values)
184-
z_values = z_values[track_ids == track_id]
216+
track_ids_points = np.array(list(map(int, point_prompts.properties["track_id"])))
217+
assert len(track_ids_points) == len(z_values)
218+
z_values = z_values[track_ids_points == track_id]
219+
220+
track_ids_boxes = np.array(list(map(int, box_prompts.properties["track_id"])))
221+
assert len(track_ids_boxes) == len(z_values_boxes), f"{len(track_ids_boxes)}, {len(z_values_boxes)}"
222+
z_values_boxes = z_values_boxes[track_ids_boxes == track_id]
185223

186224
slices = np.unique(np.concatenate([z_values, z_values_boxes])).astype("int")
187225
stop_lower, stop_upper = False, False

0 commit comments

Comments
 (0)