22import numpy as np
33
44from magicgui import magicgui
5+ from magicgui .widgets import ComboBox , Container
56from napari import Viewer
67from napari .utils import progress
78from scipy .ndimage import shift
@@ -47,12 +48,6 @@ def compute_center(t):
4748 return move .astype ("float64" )
4849
4950
50- def _update_motion_model (motion_model , move , motion_smoothing ):
51- alpha = motion_smoothing
52- motion_model = alpha * motion_model + (1 - alpha ) * move
53- return motion_model
54-
55-
5651def _shift_object (mask , motion_model ):
5752 mask_shifted = np .zeros_like (mask )
5853 shift (mask , motion_model , output = mask_shifted , order = 0 , prefilter = False )
@@ -76,40 +71,46 @@ def _update_progress():
7671 progress_bar .update (1 )
7772
7873 # shift the segmentation based on the motion model and update the motion model
79- def _update_motion (seg , t , t0 , motion_model ):
80- seg_prev = seg [t - 1 ]
81-
82- if t == t0 + 1 : # this is the second frame, we don't have a motion model yet
74+ def _update_motion_model (seg , t , t0 , motion_model ):
75+ if t in (t0 , t0 + 1 ): # this is the first or second frame, we don't have a motion yet
8376 pass
8477 elif t == t0 + 2 : # this the third frame, we initialize the motion model
8578 current_move = _compute_movement (seg , t - 1 , t - 2 )
8679 motion_model = current_move
8780 else : # we already have a motion model and update it
8881 current_move = _compute_movement (seg , t - 1 , t - 2 )
89- motion_model = _update_motion_model (motion_model , current_move , motion_smoothing )
82+ alpha = motion_smoothing
83+ motion_model = alpha * motion_model + (1 - alpha ) * current_move
9084
91- if motion_model is not None : # shift the segmentation according to the motion model
92- seg_prev = _shift_object (seg_prev , motion_model )
93-
94- return seg_prev , motion_model
85+ return motion_model
9586
87+ has_division = False
9688 motion_model = None
9789 verbose = False
9890
9991 t0 = int (slices .min ())
10092 t = t0 + 1
10193 while True :
10294
103- if t in slices : # this is a slice with prompts
95+ # update the motion model
96+ motion_model = _update_motion_model (seg , t , t0 , motion_model )
97+
98+ # use the segmentation from prompts if we are in a slice with prompts
99+ if t in slices :
104100 seg_prev = None
105101 seg_t = seg [t ]
106102 track_state = prompt_layer_to_state (prompt_layer , t )
107- # TODO what do we do with the motion model here?
108103
109- else : # this is a slice without prompts
110- seg_prev , motion_model = _update_motion ( seg , t , t0 , motion_model )
104+ # otherwise project the mask (under the motion model) and segment the next slice from the mask
105+ else :
111106 if verbose :
112107 print (f"Tracking object in frame { t } with movement { motion_model } " )
108+
109+ seg_prev = seg [t - 1 ]
110+ # shift the segmentation according to the motion model
111+ if motion_model is not None :
112+ seg_prev = _shift_object (seg_prev , motion_model )
113+
113114 seg_t = segment_from_mask (predictor , seg_prev , image_embeddings = image_embeddings , i = t ,
114115 use_mask = use_mask , use_box = use_box )
115116 track_state = "track"
@@ -142,9 +143,31 @@ def _update_motion(seg, t, t0, motion_model):
142143
143144 # stop if we have a division
144145 if track_state == "division" :
146+ has_division = True
145147 break
146148
147- return seg
149+ return seg , has_division
150+
151+
152+ def _update_lineage ():
153+ global LINEAGE , TRACKING_WIDGET
154+ mother = CURRENT_TRACK_ID
155+ assert mother in LINEAGE
156+ assert len (LINEAGE [mother ]) == 0
157+
158+ daughter1 , daughter2 = CURRENT_TRACK_ID + 1 , CURRENT_TRACK_ID + 2
159+ LINEAGE [mother ] = [daughter1 , daughter2 ]
160+ LINEAGE [daughter1 ] = []
161+ LINEAGE [daughter2 ] = []
162+
163+ # update the choices in the track_id menu
164+ track_ids = list (map (str , LINEAGE .keys ()))
165+ TRACKING_WIDGET [1 ].choices = track_ids
166+
167+ # not sure if this does the right thing.
168+ # for now the user has to take care of this manually
169+ # # reset the state to track
170+ # TRACKING_WIDGET[0].set_choice("track")
148171
149172
150173#
@@ -157,11 +180,16 @@ def segment_frame_wigdet(v: Viewer):
157180 position = v .cursor .position
158181 t = int (position [0 ])
159182
160- this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t )
183+ this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
161184 points , labels = this_prompts
162185 seg = segment_from_points (PREDICTOR , points , labels , image_embeddings = IMAGE_EMBEDDINGS , i = t )
163186
164- v .layers ["current_track" ].data [t ] = seg .squeeze ()
187+ # clear the old segmentation for this track_id
188+ old_mask = v .layers ["current_track" ].data [t ] == CURRENT_TRACK_ID
189+ v .layers ["current_track" ].data [t ][old_mask ] = 0
190+ # set the new segmentation
191+ new_mask = seg .squeeze () == 1
192+ v .layers ["current_track" ].data [t ][new_mask ] = CURRENT_TRACK_ID
165193 v .layers ["current_track" ].refresh ()
166194
167195
@@ -181,26 +209,107 @@ def track_objet_widget(
181209 with progress (total = shape [0 ]) as progress_bar :
182210 # step 1: segment all slices with prompts
183211 seg , slices , _ , stop_upper = segment_slices_with_prompts (
184- PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape , progress_bar = progress_bar
212+ PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape ,
213+ progress_bar = progress_bar , track_id = CURRENT_TRACK_ID
185214 )
186215
187216 # step 2: track the object starting from the lowest annotated slice
188- seg = _track_from_prompts (
217+ seg , has_division = _track_from_prompts (
189218 v .layers ["prompts" ], seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
190219 projection = projection_ , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
191220 )
192221
193- v .layers ["current_track" ].data = seg
222+ # if a division has occurred and it's the first time it occurred for this track
223+ # we need to create the two daughter tracks and update the lineage
224+ if has_division and (len (LINEAGE [CURRENT_TRACK_ID ]) == 0 ):
225+ _update_lineage ()
226+
227+ # clear the old track mask
228+ v .layers ["current_track" ].data [v .layers ["current_track" ].data == CURRENT_TRACK_ID ] = 0
229+ # set the new track mask
230+ v .layers ["current_track" ].data [seg == 1 ] = CURRENT_TRACK_ID
194231 v .layers ["current_track" ].refresh ()
195232
196233
234+ def create_tracking_menu (points_layer , states , track_ids ):
235+ state_menu = ComboBox (label = "track_state" , choices = states )
236+ track_id_menu = ComboBox (label = "track_id" , choices = list (map (str , track_ids )))
237+ tracking_widget = Container (widgets = [state_menu , track_id_menu ])
238+
239+ def update_state (event ):
240+ new_state = str (points_layer .current_properties ["state" ][0 ])
241+ if new_state != state_menu .value :
242+ state_menu .value = new_state
243+
244+ def update_track_id (event ):
245+ global CURRENT_TRACK_ID
246+ new_id = str (points_layer .current_properties ["track_id" ][0 ])
247+ if new_id != track_id_menu .value :
248+ state_menu .value = new_id
249+ CURRENT_TRACK_ID = int (new_id )
250+
251+ points_layer .events .current_properties .connect (update_state )
252+ points_layer .events .current_properties .connect (update_track_id )
253+
254+ def state_changed (new_state ):
255+ current_properties = points_layer .current_properties
256+ current_properties ["state" ] = np .array ([new_state ])
257+ points_layer .current_properties = current_properties
258+ points_layer .refresh_colors ()
259+
260+ def track_id_changed (new_track_id ):
261+ global CURRENT_TRACK_ID
262+ current_properties = points_layer .current_properties
263+ current_properties ["track_id" ] = np .array ([new_track_id ])
264+ points_layer .current_properties = current_properties
265+ CURRENT_TRACK_ID = int (new_track_id )
266+
267+ state_menu .changed .connect (state_changed )
268+ track_id_menu .changed .connect (track_id_changed )
269+
270+ state_menu .set_choice ("track" )
271+ return tracking_widget
272+
273+
274+ @magicgui (call_button = "Commit [C]" , layer = {"choices" : ["current_track" ]})
275+ def commit_tracking_widget (v : Viewer , layer : str = "current_track" ):
276+ global CURRENT_TRACK_ID , LINEAGE , TRACKING_WIDGET
277+
278+ seg = v .layers [layer ].data
279+
280+ id_offset = int (v .layers ["committed_tracks" ].data .max ())
281+ mask = seg != 0
282+
283+ v .layers ["committed_tracks" ].data [mask ] = (seg [mask ] + id_offset )
284+ v .layers ["committed_tracks" ].refresh ()
285+
286+ # reset the lineage and track id
287+ CURRENT_TRACK_ID = 1
288+ LINEAGE = {1 : []}
289+
290+ # reset the choices in the track_id menu
291+ track_ids = list (map (str , LINEAGE .keys ()))
292+ TRACKING_WIDGET [1 ].choices = track_ids
293+
294+ shape = v .layers ["raw" ].data .shape
295+ v .layers [layer ].data = np .zeros (shape , dtype = "uint32" )
296+ v .layers [layer ].refresh ()
297+
298+ v .layers ["prompts" ].data = []
299+ v .layers ["prompts" ].refresh ()
300+
301+
197302def annotator_tracking (raw , embedding_path = None , show_embeddings = False ):
198- # for access to the predictor and the image embeddings in the widgets
199- global PREDICTOR , IMAGE_EMBEDDINGS , NEXT_ID
200- NEXT_ID = 1
303+ # global state
304+ global PREDICTOR , IMAGE_EMBEDDINGS , CURRENT_TRACK_ID , LINEAGE
305+ global TRACKING_WIDGET
306+
201307 PREDICTOR = util .get_sam_model ()
202308 IMAGE_EMBEDDINGS = util .precompute_image_embeddings (PREDICTOR , raw , save_path = embedding_path )
203309
310+ CURRENT_TRACK_ID = 1
311+ LINEAGE = {1 : []}
312+
204313 #
205314 # initialize the viewer and add layers
206315 #
@@ -227,7 +336,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
227336 properties = {
228337 "label" : labels ,
229338 "state" : state_labels ,
230- # "track_id": [1, 1],
339+ "track_id" : ["1" , "1" ], # NOTE we use string to avoid pandas warnings...
231340 },
232341 edge_color = "label" ,
233342 edge_color_cycle = LABEL_COLOR_CYCLE ,
@@ -250,11 +359,12 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
250359 prompt_widget = create_prompt_menu (prompts , labels )
251360 v .window .add_dock_widget (prompt_widget )
252361
253- state_widget = create_prompt_menu (prompts , state_labels , menu_name = "state" , label_name = "state" )
254- v .window .add_dock_widget (state_widget )
362+ TRACKING_WIDGET = create_tracking_menu (prompts , state_labels , list ( LINEAGE . keys ()) )
363+ v .window .add_dock_widget (TRACKING_WIDGET )
255364
256365 v .window .add_dock_widget (segment_frame_wigdet )
257366 v .window .add_dock_widget (track_objet_widget )
367+ v .window .add_dock_widget (commit_tracking_widget )
258368
259369 #
260370 # key bindings
@@ -272,6 +382,10 @@ def _track_object(v):
272382 def _toggle_label (event = None ):
273383 toggle_label (prompts )
274384
385+ @v .bind_key ("c" )
386+ def _commit (v ):
387+ commit_tracking_widget (v )
388+
275389 @v .bind_key ("Shift-C" )
276390 def clear_prompts (v ):
277391 prompts .data = []
0 commit comments