22import numpy as np
33
44from magicgui import magicgui
5+ from magicgui .widgets import ComboBox , Container
56from napari import Viewer
67from napari .utils import progress
78from scipy .ndimage import shift
@@ -47,12 +48,6 @@ def compute_center(t):
4748 return move .astype ("float64" )
4849
4950
50- def _update_motion_model (motion_model , move , motion_smoothing ):
51- alpha = motion_smoothing
52- motion_model = alpha * motion_model + (1 - alpha ) * move
53- return motion_model
54-
55-
5651def _shift_object (mask , motion_model ):
5752 mask_shifted = np .zeros_like (mask )
5853 shift (mask , motion_model , output = mask_shifted , order = 0 , prefilter = False )
@@ -76,40 +71,46 @@ def _update_progress():
7671 progress_bar .update (1 )
7772
7873 # shift the segmentation based on the motion model and update the motion model
79- def _update_motion (seg , t , t0 , motion_model ):
80- seg_prev = seg [t - 1 ]
81-
82- if t == t0 + 1 : # this is the second frame, we don't have a motion model yet
74+ def _update_motion_model (seg , t , t0 , motion_model ):
75+ if t in (t0 , t0 + 1 ): # this is the first or second frame, we don't have a motion yet
8376 pass
8477 elif t == t0 + 2 : # this the third frame, we initialize the motion model
8578 current_move = _compute_movement (seg , t - 1 , t - 2 )
8679 motion_model = current_move
8780 else : # we already have a motion model and update it
8881 current_move = _compute_movement (seg , t - 1 , t - 2 )
89- motion_model = _update_motion_model (motion_model , current_move , motion_smoothing )
82+ alpha = motion_smoothing
83+ motion_model = alpha * motion_model + (1 - alpha ) * current_move
9084
91- if motion_model is not None : # shift the segmentation according to the motion model
92- seg_prev = _shift_object (seg_prev , motion_model )
93-
94- return seg_prev , motion_model
85+ return motion_model
9586
87+ has_division = False
9688 motion_model = None
9789 verbose = False
9890
9991 t0 = int (slices .min ())
10092 t = t0 + 1
10193 while True :
10294
103- if t in slices : # this is a slice with prompts
95+ # update the motion model
96+ motion_model = _update_motion_model (seg , t , t0 , motion_model )
97+
98+ # use the segmentation from prompts if we are in a slice with prompts
99+ if t in slices :
104100 seg_prev = None
105101 seg_t = seg [t ]
106102 track_state = prompt_layer_to_state (prompt_layer , t )
107- # TODO what do we do with the motion model here?
108103
109- else : # this is a slice without prompts
110- seg_prev , motion_model = _update_motion ( seg , t , t0 , motion_model )
104+ # otherwise project the mask (under the motion model) and segment the next slice from the mask
105+ else :
111106 if verbose :
112107 print (f"Tracking object in frame { t } with movement { motion_model } " )
108+
109+ seg_prev = seg [t - 1 ]
110+ # shift the segmentation according to the motion model
111+ if motion_model is not None :
112+ seg_prev = _shift_object (seg_prev , motion_model )
113+
113114 seg_t = segment_from_mask (predictor , seg_prev , image_embeddings = image_embeddings , i = t ,
114115 use_mask = use_mask , use_box = use_box )
115116 track_state = "track"
@@ -142,9 +143,31 @@ def _update_motion(seg, t, t0, motion_model):
142143
143144 # stop if we have a division
144145 if track_state == "division" :
146+ has_division = True
145147 break
146148
147- return seg
149+ return seg , has_division
150+
151+
152+ def _update_lineage ():
153+ global LINEAGE , TRACKING_WIDGET
154+ mother = CURRENT_TRACK_ID
155+ assert mother in LINEAGE
156+ assert len (LINEAGE [mother ]) == 0
157+
158+ daughter1 , daughter2 = CURRENT_TRACK_ID + 1 , CURRENT_TRACK_ID + 2
159+ LINEAGE [mother ] = [daughter1 , daughter2 ]
160+ LINEAGE [daughter1 ] = []
161+ LINEAGE [daughter2 ] = []
162+
163+ # update the choices in the track_id menu
164+ track_ids = list (map (str , LINEAGE .keys ()))
165+ TRACKING_WIDGET [1 ].choices = track_ids
166+
167+ # not sure if this does the right thing.
168+ # for now the user has to take care of this manually
169+ # # reset the state to track
170+ # TRACKING_WIDGET[0].set_choice("track")
148171
149172
150173#
@@ -157,11 +180,16 @@ def segment_frame_wigdet(v: Viewer):
157180 position = v .cursor .position
158181 t = int (position [0 ])
159182
160- this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t )
183+ this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
161184 points , labels = this_prompts
162185 seg = segment_from_points (PREDICTOR , points , labels , image_embeddings = IMAGE_EMBEDDINGS , i = t )
163186
164- v .layers ["current_track" ].data [t ] = seg .squeeze ()
187+ # clear the old segmentation for this track_id
188+ old_mask = v .layers ["current_track" ].data [t ] == CURRENT_TRACK_ID
189+ v .layers ["current_track" ].data [t ][old_mask ] = 0
190+ # set the new segmentation
191+ new_mask = seg .squeeze () == 1
192+ v .layers ["current_track" ].data [t ][new_mask ] = CURRENT_TRACK_ID
165193 v .layers ["current_track" ].refresh ()
166194
167195
@@ -181,26 +209,79 @@ def track_objet_widget(
181209 with progress (total = shape [0 ]) as progress_bar :
182210 # step 1: segment all slices with prompts
183211 seg , slices , _ , stop_upper = segment_slices_with_prompts (
184- PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape , progress_bar = progress_bar
212+ PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape ,
213+ progress_bar = progress_bar , track_id = CURRENT_TRACK_ID
185214 )
186215
187216 # step 2: track the object starting from the lowest annotated slice
188- seg = _track_from_prompts (
217+ seg , has_division = _track_from_prompts (
189218 v .layers ["prompts" ], seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
190219 projection = projection_ , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
191220 )
192221
193- v .layers ["current_track" ].data = seg
222+ # if a division has occurred and it's the first time it occurred for this track
223+ # we need to create the two daughter tracks and update the lineage
224+ if has_division and len (LINEAGE [CURRENT_TRACK_ID ]) == 0 :
225+ _update_lineage ()
226+
227+ # clear the old track mask
228+ v .layers ["current_track" ].data [v .layers ["current_track" ].data == CURRENT_TRACK_ID ] = 0
229+ # set the new track mask
230+ v .layers ["current_track" ].data [seg == 1 ] = CURRENT_TRACK_ID
194231 v .layers ["current_track" ].refresh ()
195232
196233
234+ def create_tracking_menu (points_layer , states , track_ids ):
235+ state_menu = ComboBox (label = "track_state" , choices = states )
236+ track_id_menu = ComboBox (label = "track_id" , choices = list (map (str , track_ids )))
237+ tracking_widget = Container (widgets = [state_menu , track_id_menu ])
238+
239+ def update_state (event ):
240+ new_state = str (points_layer .current_properties ["state" ][0 ])
241+ if new_state != state_menu .value :
242+ state_menu .value = new_state
243+
244+ def update_track_id (event ):
245+ global CURRENT_TRACK_ID
246+ new_id = str (points_layer .current_properties ["track_id" ][0 ])
247+ if new_id != track_id_menu .value :
248+ state_menu .value = new_id
249+ CURRENT_TRACK_ID = int (new_id )
250+
251+ points_layer .events .current_properties .connect (update_state )
252+ points_layer .events .current_properties .connect (update_track_id )
253+
254+ def state_changed (new_state ):
255+ current_properties = points_layer .current_properties
256+ current_properties ["state" ] = np .array ([new_state ])
257+ points_layer .current_properties = current_properties
258+ points_layer .refresh_colors ()
259+
260+ def track_id_changed (new_track_id ):
261+ global CURRENT_TRACK_ID
262+ current_properties = points_layer .current_properties
263+ current_properties ["track_id" ] = np .array ([new_track_id ])
264+ points_layer .current_properties = current_properties
265+ CURRENT_TRACK_ID = int (new_track_id )
266+
267+ state_menu .changed .connect (state_changed )
268+ track_id_menu .changed .connect (track_id_changed )
269+
270+ state_menu .set_choice ("track" )
271+ return tracking_widget
272+
273+
197274def annotator_tracking (raw , embedding_path = None , show_embeddings = False ):
198- # for access to the predictor and the image embeddings in the widgets
199- global PREDICTOR , IMAGE_EMBEDDINGS , NEXT_ID
200- NEXT_ID = 1
275+ # global state
276+ global PREDICTOR , IMAGE_EMBEDDINGS , CURRENT_TRACK_ID , LINEAGE
277+ global TRACKING_WIDGET
278+
201279 PREDICTOR = util .get_sam_model ()
202280 IMAGE_EMBEDDINGS = util .precompute_image_embeddings (PREDICTOR , raw , save_path = embedding_path )
203281
282+ CURRENT_TRACK_ID = 1
283+ LINEAGE = {1 : []}
284+
204285 #
205286 # initialize the viewer and add layers
206287 #
@@ -227,7 +308,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
227308 properties = {
228309 "label" : labels ,
229310 "state" : state_labels ,
230- # "track_id": [1, 1],
311+ "track_id" : ["1" , "1" ], # NOTE we use string to avoid pandas warnings...
231312 },
232313 edge_color = "label" ,
233314 edge_color_cycle = LABEL_COLOR_CYCLE ,
@@ -250,8 +331,8 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False):
250331 prompt_widget = create_prompt_menu (prompts , labels )
251332 v .window .add_dock_widget (prompt_widget )
252333
253- state_widget = create_prompt_menu (prompts , state_labels , menu_name = "state" , label_name = "state" )
254- v .window .add_dock_widget (state_widget )
334+ TRACKING_WIDGET = create_tracking_menu (prompts , state_labels , list ( LINEAGE . keys ()) )
335+ v .window .add_dock_widget (TRACKING_WIDGET )
255336
256337 v .window .add_dock_widget (segment_frame_wigdet )
257338 v .window .add_dock_widget (track_objet_widget )
0 commit comments