1111
1212from .. import util
1313from ..segment_from_prompts import segment_from_mask , segment_from_points
14- from .util import create_prompt_menu , prompt_layer_to_points , segment_slices_with_prompts , LABEL_COLOR_CYCLE
14+ from .util import (
15+ create_prompt_menu , prompt_layer_to_points , prompt_layer_to_state , segment_slices_with_prompts , LABEL_COLOR_CYCLE
16+ )
1517from ..visualization import project_embeddings_for_visualization
1618
1719# Magenta and Cyan
@@ -58,7 +60,8 @@ def _shift_object(mask, motion_model):
5860
5961# TODO handle divison annotations + division classifier
6062def _track_from_prompts (
61- seg , predictor , slices , image_embeddings , stop_upper , threshold , projection ,
63+ prompt_layer , seg , predictor , slices , image_embeddings ,
64+ stop_upper , threshold , projection ,
6265 progress_bar = None , motion_smoothing = 0.5 ,
6366):
6467 assert projection in ("mask" , "bounding_box" )
@@ -95,16 +98,28 @@ def _update_motion(seg, t, t0, motion_model):
9598 t0 = int (slices .min ())
9699 t = t0 + 1
97100 while True :
98- if t in slices :
101+
102+ if t in slices : # this is a slice with prompts
99103 seg_prev = None
100104 seg_t = seg [t ]
101- else :
105+ track_state = prompt_layer_to_state (prompt_layer , t )
106+ # TODO what do we do with the motion model here?
107+
108+ else : # this is a slice without prompts
102109 seg_prev , motion_model = _update_motion (seg , t , t0 , motion_model )
110+ if verbose :
111+ print (f"Tracking object in frame { t } with movement { motion_model } " )
103112 seg_t = segment_from_mask (predictor , seg_prev , image_embeddings = image_embeddings , i = t ,
104113 use_mask = use_mask , use_box = use_box )
114+ track_state = "track"
115+
116+ # are we beyond the last slice with prompt?
117+ # if no: we continue tracking because we know we need to connect to a future frame
118+ # if yes: we only continue tracking if overlaps are above the threshold
119+ if t < slices [- 1 ]:
120+ seg_prev = None
121+
105122 _update_progress ()
106- if verbose :
107- print (f"Tracking object in frame { t } with movement { motion_model } " )
108123
109124 if (threshold is not None ) and (seg_prev is not None ):
110125 iou = util .compute_iou (seg_prev , seg_t )
@@ -116,7 +131,6 @@ def _update_motion(seg, t, t0, motion_model):
116131 seg [t ] = seg_t
117132 t += 1
118133
119- # TODO here we need to also stop once divisions are implemented
120134 # stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
121135 if t == slices [- 1 ] and stop_upper :
122136 break
@@ -125,6 +139,10 @@ def _update_motion(seg, t, t0, motion_model):
125139 if t == seg .shape [0 ] - 1 :
126140 break
127141
142+ # stop if we have a division
143+ if track_state == "division" :
144+ break
145+
128146 return seg
129147
130148
@@ -167,7 +185,7 @@ def track_objet_widget(
167185
168186 # step 2: track the object starting from the lowest annotated slice
169187 seg = _track_from_prompts (
170- seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
188+ v . layers [ "prompts" ], seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
171189 projection = projection_ , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
172190 )
173191
0 commit comments