1111# from vigra.filters import eccentricityCenters
1212
1313from .. import util
14- from ..segment_from_prompts import segment_from_mask , segment_from_points
14+ from ..segment_from_prompts import segment_from_mask
1515from .util import (
16- create_prompt_menu , prompt_layer_to_points , prompt_layer_to_state ,
16+ create_prompt_menu , clear_all_prompts ,
17+ prompt_layer_to_boxes , prompt_layer_to_points ,
18+ prompt_layer_to_state , prompt_segmentation ,
1719 segment_slices_with_prompts , toggle_label , LABEL_COLOR_CYCLE
1820)
1921from ..visualization import project_embeddings_for_visualization
2022
21- # Magenta and Cyan
22- STATE_COLOR_CYCLE = ["#FF00FF " , "#00FFFF" ]
23+ # Cyan (track) and Magenta (division)
24+ STATE_COLOR_CYCLE = ["#00FFFF " , "#FF00FF" , ]
2325
2426
2527#
@@ -56,7 +58,7 @@ def _shift_object(mask, motion_model):
5658
5759# TODO division classifier
5860def _track_from_prompts (
59- prompt_layer , seg , predictor , slices , image_embeddings ,
61+ point_prompts , box_prompts , seg , predictor , slices , image_embeddings ,
6062 stop_upper , threshold , projection ,
6163 progress_bar = None , motion_smoothing = 0.5 ,
6264):
@@ -99,7 +101,9 @@ def _update_motion_model(seg, t, t0, motion_model):
99101 if t in slices :
100102 seg_prev = None
101103 seg_t = seg [t ]
102- track_state = prompt_layer_to_state (prompt_layer , t )
104+ # currently using the box layer doesn't work for keeping track of the track state
105+ # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
106+ track_state = prompt_layer_to_state (point_prompts , t )
103107
104108 # otherwise project the mask (under the motion model) and segment the next slice from the mask
105109 else :
@@ -138,7 +142,7 @@ def _update_motion_model(seg, t, t0, motion_model):
138142 break
139143
140144 # stop if we are at the last slce
141- if t == seg .shape [0 ] - 1 :
145+ if t == seg .shape [0 ]:
142146 break
143147
144148 # stop if we have a division
@@ -180,9 +184,24 @@ def segment_frame_wigdet(v: Viewer):
180184 position = v .cursor .position
181185 t = int (position [0 ])
182186
183- this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
184- points , labels = this_prompts
185- seg = segment_from_points (PREDICTOR , points , labels , image_embeddings = IMAGE_EMBEDDINGS , i = t )
187+ point_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
188+ # this is a stop prompt, we do nothing
189+ if not point_prompts :
190+ return
191+
192+ boxes = prompt_layer_to_boxes (v .layers ["box_prompts" ], t , track_id = CURRENT_TRACK_ID )
193+ points , labels = point_prompts
194+
195+ shape = v .layers ["current_track" ].data .shape [1 :]
196+ seg = prompt_segmentation (
197+ PREDICTOR , points , labels , boxes , shape , multiple_box_prompts = False ,
198+ image_embeddings = IMAGE_EMBEDDINGS , i = t
199+ )
200+
201+ # no prompts were given or prompts were invalid, skip segmentation
202+ if seg is None :
203+ print ("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped." )
204+ return
186205
187206 # clear the old segmentation for this track_id
188207 old_mask = v .layers ["current_track" ].data [t ] == CURRENT_TRACK_ID
@@ -199,24 +218,23 @@ def track_objet_widget(
199218):
200219 shape = v .layers ["raw" ].data .shape
201220
202- # choose mask projection for square images and bounding box projection otherwise
203- # (because mask projection does not work properly for non-square images yet)
204- if projection == "default" :
205- projection_ = "mask" if shape [1 ] == shape [2 ] else "bounding_box"
206- else :
207- projection_ = projection
221+ # we use the bounding box projection method as default which generally seems to work better for larger changes
222+ # between frames (which is pretty tyipical for tracking compared to 3d segmentation)
223+ projection_ = "bounding_box" if projection == "default" else projection
208224
209225 with progress (total = shape [0 ]) as progress_bar :
210226 # step 1: segment all slices with prompts
211227 seg , slices , _ , stop_upper = segment_slices_with_prompts (
212- PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape ,
228+ PREDICTOR , v .layers ["prompts" ], v . layers [ "box_prompts" ], IMAGE_EMBEDDINGS , shape ,
213229 progress_bar = progress_bar , track_id = CURRENT_TRACK_ID
214230 )
215231
216232 # step 2: track the object starting from the lowest annotated slice
217233 seg , has_division = _track_from_prompts (
218- v .layers ["prompts" ], seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
219- projection = projection_ , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
234+ v .layers ["prompts" ], v .layers ["box_prompts" ], seg ,
235+ PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper ,
236+ threshold = iou_threshold , projection = projection_ ,
237+ progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
220238 )
221239
222240 # if a division has occurred and it's the first time it occurred for this track
@@ -231,7 +249,7 @@ def track_objet_widget(
231249 v .layers ["current_track" ].refresh ()
232250
233251
234- def create_tracking_menu (points_layer , states , track_ids ):
252+ def create_tracking_menu (points_layer , box_layer , states , track_ids ):
235253 state_menu = ComboBox (label = "track_state" , choices = states )
236254 track_id_menu = ComboBox (label = "track_id" , choices = list (map (str , track_ids )))
237255 tracking_widget = Container (widgets = [state_menu , track_id_menu ])
@@ -245,11 +263,25 @@ def update_track_id(event):
245263 global CURRENT_TRACK_ID
246264 new_id = str (points_layer .current_properties ["track_id" ][0 ])
247265 if new_id != track_id_menu .value :
248- state_menu .value = new_id
266+ track_id_menu .value = new_id
267+ CURRENT_TRACK_ID = int (new_id )
268+
269+ # def update_state_boxes(event):
270+ # new_state = str(box_layer.current_properties["state"][0])
271+ # if new_state != state_menu.value:
272+ # state_menu.value = new_state
273+
274+ def update_track_id_boxes (event ):
275+ global CURRENT_TRACK_ID
276+ new_id = str (box_layer .current_properties ["track_id" ][0 ])
277+ if new_id != track_id_menu .value :
278+ track_id_menu .value = new_id
249279 CURRENT_TRACK_ID = int (new_id )
250280
251281 points_layer .events .current_properties .connect (update_state )
252282 points_layer .events .current_properties .connect (update_track_id )
283+ # box_layer.events.current_properties.connect(update_state_boxes)
284+ box_layer .events .current_properties .connect (update_track_id_boxes )
253285
254286 def state_changed (new_state ):
255287 current_properties = points_layer .current_properties
@@ -264,8 +296,23 @@ def track_id_changed(new_track_id):
264296 points_layer .current_properties = current_properties
265297 CURRENT_TRACK_ID = int (new_track_id )
266298
299+ # def state_changed_boxes(new_state):
300+ # current_properties = box_layer.current_properties
301+ # current_properties["state"] = np.array([new_state])
302+ # box_layer.current_properties = current_properties
303+ # box_layer.refresh_colors()
304+
305+ def track_id_changed_boxes (new_track_id ):
306+ global CURRENT_TRACK_ID
307+ current_properties = box_layer .current_properties
308+ current_properties ["track_id" ] = np .array ([new_track_id ])
309+ box_layer .current_properties = current_properties
310+ CURRENT_TRACK_ID = int (new_track_id )
311+
267312 state_menu .changed .connect (state_changed )
268313 track_id_menu .changed .connect (track_id_changed )
314+ # state_menu.changed.connect(state_changed_boxes)
315+ track_id_menu .changed .connect (track_id_changed_boxes )
269316
270317 state_menu .set_choice ("track" )
271318 return tracking_widget
@@ -295,8 +342,7 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
295342 v .layers [layer ].data = np .zeros (shape , dtype = "uint32" )
296343 v .layers [layer ].refresh ()
297344
298- v .layers ["prompts" ].data = []
299- v .layers ["prompts" ].refresh ()
345+ clear_all_prompts (v )
300346
301347
302348def annotator_tracking (raw , embedding_path = None , show_embeddings = False , tracking_result = None , model_type = "vit_h" ):
@@ -333,7 +379,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
333379 # add the widgets
334380 #
335381 labels = ["positive" , "negative" ]
336- state_labels = ["division " , "track " ]
382+ state_labels = ["track " , "division " ]
337383 prompts = v .add_points (
338384 data = [[0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ]], # FIXME workaround
339385 name = "prompts" ,
@@ -354,6 +400,24 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
354400 prompts .edge_color_mode = "cycle"
355401 prompts .face_color_mode = "cycle"
356402
403+ # using the box layer to set divisions currently doesn't work
404+ # (and setting new track ids also doesn't work, but keeping track of them in the properties is working)
405+ box_prompts = v .add_shapes (
406+ data = [
407+ np .array ([[0 , 0 , 0 ], [0 , 0 , 10 ], [0 , 10 , 0 ], [0 , 10 , 10 ]]),
408+ np .array ([[0 , 0 , 0 ], [0 , 0 , 11 ], [0 , 11 , 0 ], [0 , 11 , 11 ]]),
409+ ], # FIXME workaround
410+ shape_type = "rectangle" , # FIXME workaround
411+ edge_width = 4 , ndim = 3 ,
412+ face_color = "transparent" ,
413+ name = "box_prompts" ,
414+ edge_color = "green" ,
415+ properties = {"track_id" : ["1" , "1" ]},
416+ # properties={"track_id": ["1", "1"], "state": state_labels},
417+ # edge_color_cycle=STATE_COLOR_CYCLE,
418+ )
419+ # box_prompts.edge_color_mode = "cycle"
420+
357421 #
358422 # add the widgets
359423 #
@@ -363,7 +427,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
363427 prompt_widget = create_prompt_menu (prompts , labels )
364428 v .window .add_dock_widget (prompt_widget )
365429
366- TRACKING_WIDGET = create_tracking_menu (prompts , state_labels , list (LINEAGE .keys ()))
430+ TRACKING_WIDGET = create_tracking_menu (prompts , box_prompts , state_labels , list (LINEAGE .keys ()))
367431 v .window .add_dock_widget (TRACKING_WIDGET )
368432
369433 v .window .add_dock_widget (segment_frame_wigdet )
@@ -392,8 +456,7 @@ def _commit(v):
392456
393457 @v .bind_key ("Shift-C" )
394458 def clear_prompts (v ):
395- prompts .data = []
396- prompts .refresh ()
459+ clear_all_prompts (v )
397460
398461 #
399462 # start the viewer
0 commit comments