Skip to content

Commit 91348ac

Browse files
Merge pull request #6 from computational-cell-analytics/tracking
First tracking annotator prototype
2 parents cf1ca66 + 72d275a commit 91348ac

File tree

2 files changed

+162
-35
lines changed

2 files changed

+162
-35
lines changed

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 145 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from magicgui import magicgui
5+
from magicgui.widgets import ComboBox, Container
56
from napari import Viewer
67
from napari.utils import progress
78
from scipy.ndimage import shift
@@ -47,12 +48,6 @@ def compute_center(t):
4748
return move.astype("float64")
4849

4950

50-
def _update_motion_model(motion_model, move, motion_smoothing):
51-
alpha = motion_smoothing
52-
motion_model = alpha * motion_model + (1 - alpha) * move
53-
return motion_model
54-
55-
5651
def _shift_object(mask, motion_model):
5752
mask_shifted = np.zeros_like(mask)
5853
shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
@@ -76,40 +71,46 @@ def _update_progress():
7671
progress_bar.update(1)
7772

7873
# shift the segmentation based on the motion model and update the motion model
79-
def _update_motion(seg, t, t0, motion_model):
80-
seg_prev = seg[t - 1]
81-
82-
if t == t0 + 1: # this is the second frame, we don't have a motion model yet
74+
def _update_motion_model(seg, t, t0, motion_model):
75+
if t in (t0, t0 + 1): # this is the first or second frame, we don't have a motion yet
8376
pass
8477
elif t == t0 + 2: # this the third frame, we initialize the motion model
8578
current_move = _compute_movement(seg, t - 1, t - 2)
8679
motion_model = current_move
8780
else: # we already have a motion model and update it
8881
current_move = _compute_movement(seg, t - 1, t - 2)
89-
motion_model = _update_motion_model(motion_model, current_move, motion_smoothing)
82+
alpha = motion_smoothing
83+
motion_model = alpha * motion_model + (1 - alpha) * current_move
9084

91-
if motion_model is not None: # shift the segmentation according to the motion model
92-
seg_prev = _shift_object(seg_prev, motion_model)
93-
94-
return seg_prev, motion_model
85+
return motion_model
9586

87+
has_division = False
9688
motion_model = None
9789
verbose = False
9890

9991
t0 = int(slices.min())
10092
t = t0 + 1
10193
while True:
10294

103-
if t in slices: # this is a slice with prompts
95+
# update the motion model
96+
motion_model = _update_motion_model(seg, t, t0, motion_model)
97+
98+
# use the segmentation from prompts if we are in a slice with prompts
99+
if t in slices:
104100
seg_prev = None
105101
seg_t = seg[t]
106102
track_state = prompt_layer_to_state(prompt_layer, t)
107-
# TODO what do we do with the motion model here?
108103

109-
else: # this is a slice without prompts
110-
seg_prev, motion_model = _update_motion(seg, t, t0, motion_model)
104+
# otherwise project the mask (under the motion model) and segment the next slice from the mask
105+
else:
111106
if verbose:
112107
print(f"Tracking object in frame {t} with movement {motion_model}")
108+
109+
seg_prev = seg[t - 1]
110+
# shift the segmentation according to the motion model
111+
if motion_model is not None:
112+
seg_prev = _shift_object(seg_prev, motion_model)
113+
113114
seg_t = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=t,
114115
use_mask=use_mask, use_box=use_box)
115116
track_state = "track"
@@ -142,9 +143,31 @@ def _update_motion(seg, t, t0, motion_model):
142143

143144
# stop if we have a division
144145
if track_state == "division":
146+
has_division = True
145147
break
146148

147-
return seg
149+
return seg, has_division
150+
151+
152+
def _update_lineage():
153+
global LINEAGE, TRACKING_WIDGET
154+
mother = CURRENT_TRACK_ID
155+
assert mother in LINEAGE
156+
assert len(LINEAGE[mother]) == 0
157+
158+
daughter1, daughter2 = CURRENT_TRACK_ID + 1, CURRENT_TRACK_ID + 2
159+
LINEAGE[mother] = [daughter1, daughter2]
160+
LINEAGE[daughter1] = []
161+
LINEAGE[daughter2] = []
162+
163+
# update the choices in the track_id menu
164+
track_ids = list(map(str, LINEAGE.keys()))
165+
TRACKING_WIDGET[1].choices = track_ids
166+
167+
# not sure if this does the right thing.
168+
# for now the user has to take care of this manually
169+
# # reset the state to track
170+
# TRACKING_WIDGET[0].set_choice("track")
148171

149172

150173
#
@@ -157,11 +180,16 @@ def segment_frame_wigdet(v: Viewer):
157180
position = v.cursor.position
158181
t = int(position[0])
159182

160-
this_prompts = prompt_layer_to_points(v.layers["prompts"], t)
183+
this_prompts = prompt_layer_to_points(v.layers["prompts"], t, track_id=CURRENT_TRACK_ID)
161184
points, labels = this_prompts
162185
seg = segment_from_points(PREDICTOR, points, labels, image_embeddings=IMAGE_EMBEDDINGS, i=t)
163186

164-
v.layers["current_track"].data[t] = seg.squeeze()
187+
# clear the old segmentation for this track_id
188+
old_mask = v.layers["current_track"].data[t] == CURRENT_TRACK_ID
189+
v.layers["current_track"].data[t][old_mask] = 0
190+
# set the new segmentation
191+
new_mask = seg.squeeze() == 1
192+
v.layers["current_track"].data[t][new_mask] = CURRENT_TRACK_ID
165193
v.layers["current_track"].refresh()
166194

167195

@@ -181,26 +209,107 @@ def track_objet_widget(
181209
with progress(total=shape[0]) as progress_bar:
182210
# step 1: segment all slices with prompts
183211
seg, slices, _, stop_upper = segment_slices_with_prompts(
184-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar
212+
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape,
213+
progress_bar=progress_bar, track_id=CURRENT_TRACK_ID
185214
)
186215

187216
# step 2: track the object starting from the lowest annotated slice
188-
seg = _track_from_prompts(
217+
seg, has_division = _track_from_prompts(
189218
v.layers["prompts"], seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
190219
projection=projection_, progress_bar=progress_bar, motion_smoothing=motion_smoothing,
191220
)
192221

193-
v.layers["current_track"].data = seg
222+
# if a division has occurred and it's the first time it occurred for this track
223+
# we need to create the two daughter tracks and update the lineage
224+
if has_division and (len(LINEAGE[CURRENT_TRACK_ID]) == 0):
225+
_update_lineage()
226+
227+
# clear the old track mask
228+
v.layers["current_track"].data[v.layers["current_track"].data == CURRENT_TRACK_ID] = 0
229+
# set the new track mask
230+
v.layers["current_track"].data[seg == 1] = CURRENT_TRACK_ID
194231
v.layers["current_track"].refresh()
195232

196233

234+
def create_tracking_menu(points_layer, states, track_ids):
235+
state_menu = ComboBox(label="track_state", choices=states)
236+
track_id_menu = ComboBox(label="track_id", choices=list(map(str, track_ids)))
237+
tracking_widget = Container(widgets=[state_menu, track_id_menu])
238+
239+
def update_state(event):
240+
new_state = str(points_layer.current_properties["state"][0])
241+
if new_state != state_menu.value:
242+
state_menu.value = new_state
243+
244+
def update_track_id(event):
245+
global CURRENT_TRACK_ID
246+
new_id = str(points_layer.current_properties["track_id"][0])
247+
if new_id != track_id_menu.value:
248+
state_menu.value = new_id
249+
CURRENT_TRACK_ID = int(new_id)
250+
251+
points_layer.events.current_properties.connect(update_state)
252+
points_layer.events.current_properties.connect(update_track_id)
253+
254+
def state_changed(new_state):
255+
current_properties = points_layer.current_properties
256+
current_properties["state"] = np.array([new_state])
257+
points_layer.current_properties = current_properties
258+
points_layer.refresh_colors()
259+
260+
def track_id_changed(new_track_id):
261+
global CURRENT_TRACK_ID
262+
current_properties = points_layer.current_properties
263+
current_properties["track_id"] = np.array([new_track_id])
264+
points_layer.current_properties = current_properties
265+
CURRENT_TRACK_ID = int(new_track_id)
266+
267+
state_menu.changed.connect(state_changed)
268+
track_id_menu.changed.connect(track_id_changed)
269+
270+
state_menu.set_choice("track")
271+
return tracking_widget
272+
273+
274+
@magicgui(call_button="Commit [C]", layer={"choices": ["current_track"]})
275+
def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
276+
global CURRENT_TRACK_ID, LINEAGE, TRACKING_WIDGET
277+
278+
seg = v.layers[layer].data
279+
280+
id_offset = int(v.layers["committed_tracks"].data.max())
281+
mask = seg != 0
282+
283+
v.layers["committed_tracks"].data[mask] = (seg[mask] + id_offset)
284+
v.layers["committed_tracks"].refresh()
285+
286+
# reset the lineage and track id
287+
CURRENT_TRACK_ID = 1
288+
LINEAGE = {1: []}
289+
290+
# reset the choices in the track_id menu
291+
track_ids = list(map(str, LINEAGE.keys()))
292+
TRACKING_WIDGET[1].choices = track_ids
293+
294+
shape = v.layers["raw"].data.shape
295+
v.layers[layer].data = np.zeros(shape, dtype="uint32")
296+
v.layers[layer].refresh()
297+
298+
v.layers["prompts"].data = []
299+
v.layers["prompts"].refresh()
300+
301+
197302
def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
198-
# for access to the predictor and the image embeddings in the widgets
199-
global PREDICTOR, IMAGE_EMBEDDINGS, NEXT_ID
200-
NEXT_ID = 1
303+
# global state
304+
global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE
305+
global TRACKING_WIDGET
306+
201307
PREDICTOR = util.get_sam_model()
202308
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
203309

310+
CURRENT_TRACK_ID = 1
311+
LINEAGE = {1: []}
312+
204313
#
205314
# initialize the viewer and add layers
206315
#
@@ -227,7 +336,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
227336
properties={
228337
"label": labels,
229338
"state": state_labels,
230-
# "track_id": [1, 1],
339+
"track_id": ["1", "1"], # NOTE we use string to avoid pandas warnings...
231340
},
232341
edge_color="label",
233342
edge_color_cycle=LABEL_COLOR_CYCLE,
@@ -250,11 +359,12 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
250359
prompt_widget = create_prompt_menu(prompts, labels)
251360
v.window.add_dock_widget(prompt_widget)
252361

253-
state_widget = create_prompt_menu(prompts, state_labels, menu_name="state", label_name="state")
254-
v.window.add_dock_widget(state_widget)
362+
TRACKING_WIDGET = create_tracking_menu(prompts, state_labels, list(LINEAGE.keys()))
363+
v.window.add_dock_widget(TRACKING_WIDGET)
255364

256365
v.window.add_dock_widget(segment_frame_wigdet)
257366
v.window.add_dock_widget(track_objet_widget)
367+
v.window.add_dock_widget(commit_tracking_widget)
258368

259369
#
260370
# key bindings
@@ -272,6 +382,10 @@ def _track_object(v):
272382
def _toggle_label(event=None):
273383
toggle_label(prompts)
274384

385+
@v.bind_key("c")
386+
def _commit(v):
387+
commit_tracking_widget(v)
388+
275389
@v.bind_key("Shift-C")
276390
def clear_prompts(v):
277391
prompts.data = []

micro_sam/sam_annotator/util.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def label_changed(new_label):
5151
return label_widget
5252

5353

54-
def prompt_layer_to_points(prompt_layer, i=None):
54+
def prompt_layer_to_points(prompt_layer, i=None, track_id=None):
5555
"""Extract point prompts for SAM from point layer.
5656
5757
Arguments:
@@ -73,6 +73,13 @@ def prompt_layer_to_points(prompt_layer, i=None):
7373
this_labels = labels[mask]
7474
assert len(this_points) == len(this_labels)
7575

76+
if track_id is not None:
77+
assert i is not None
78+
track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))[mask]
79+
track_id_mask = track_ids == track_id
80+
this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
81+
assert len(this_points) == len(this_labels)
82+
7683
this_labels = np.array([1 if label == "positive" else 0 for label in this_labels])
7784
# a single point with a negative label is interpreted as 'stop' signal
7885
# in this case we return None
@@ -106,18 +113,24 @@ def prompt_layer_to_state(prompt_layer, i):
106113
return "track"
107114

108115

109-
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape, progress_bar=None):
116+
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape, progress_bar=None, track_id=None):
110117
seg = np.zeros(shape, dtype="uint32")
111118

112-
slices = np.unique(prompt_layer.data[:, 0]).astype("int")
119+
z_values = prompt_layer.data[:, 0]
120+
if track_id is not None:
121+
track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))
122+
assert len(track_ids) == len(z_values)
123+
z_values = z_values[track_ids == track_id]
124+
125+
slices = np.unique(z_values).astype("int")
113126
stop_lower, stop_upper = False, False
114127

115128
def _update_progress():
116129
if progress_bar is not None:
117130
progress_bar.update(1)
118131

119132
for i in slices:
120-
prompts_i = prompt_layer_to_points(prompt_layer, i)
133+
prompts_i = prompt_layer_to_points(prompt_layer, i, track_id)
121134

122135
# do we end the segmentation at the outer slices?
123136
if prompts_i is None:

0 commit comments

Comments
 (0)