1111# from vigra.filters import eccentricityCenters
1212
1313from .. import util
14- from ..segment_from_prompts import segment_from_mask , segment_from_points
14+ from ..segment_from_prompts import segment_from_mask
1515from .util import (
16- create_prompt_menu , prompt_layer_to_points , prompt_layer_to_state ,
16+ create_prompt_menu , clear_all_prompts ,
17+ prompt_layer_to_boxes , prompt_layer_to_points ,
18+ prompt_layer_to_state , prompt_segmentation ,
1719 segment_slices_with_prompts , toggle_label , LABEL_COLOR_CYCLE
1820)
1921from ..visualization import project_embeddings_for_visualization
2022
21- # Magenta and Cyan
22- STATE_COLOR_CYCLE = ["#FF00FF " , "#00FFFF" ]
23+ # Cyan (track) and Magenta (division)
24+ STATE_COLOR_CYCLE = ["#00FFFF " , "#FF00FF" , ]
2325
2426
2527#
@@ -56,7 +58,7 @@ def _shift_object(mask, motion_model):
5658
5759# TODO division classifier
5860def _track_from_prompts (
59- prompt_layer , seg , predictor , slices , image_embeddings ,
61+ point_prompts , box_prompts , seg , predictor , slices , image_embeddings ,
6062 stop_upper , threshold , projection ,
6163 progress_bar = None , motion_smoothing = 0.5 ,
6264):
@@ -99,7 +101,9 @@ def _update_motion_model(seg, t, t0, motion_model):
99101 if t in slices :
100102 seg_prev = None
101103 seg_t = seg [t ]
102- track_state = prompt_layer_to_state (prompt_layer , t )
104+ # currently using the box layer doesn't work for keeping track of the track state
105+ # track_state = prompt_layers_to_state(point_prompts, box_prompts, t)
106+ track_state = prompt_layer_to_state (point_prompts , t )
103107
104108 # otherwise project the mask (under the motion model) and segment the next slice from the mask
105109 else :
@@ -180,9 +184,24 @@ def segment_frame_wigdet(v: Viewer):
180184 position = v .cursor .position
181185 t = int (position [0 ])
182186
183- this_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
184- points , labels = this_prompts
185- seg = segment_from_points (PREDICTOR , points , labels , image_embeddings = IMAGE_EMBEDDINGS , i = t )
187+ point_prompts = prompt_layer_to_points (v .layers ["prompts" ], t , track_id = CURRENT_TRACK_ID )
188+ # this is a stop prompt, we do nothing
189+ if not point_prompts :
190+ return
191+
192+ boxes = prompt_layer_to_boxes (v .layers ["box_prompts" ], t , track_id = CURRENT_TRACK_ID )
193+ points , labels = point_prompts
194+
195+ shape = v .layers ["current_track" ].data .shape [1 :]
196+ seg = prompt_segmentation (
197+ PREDICTOR , points , labels , boxes , shape , multiple_box_prompts = False ,
198+ image_embeddings = IMAGE_EMBEDDINGS , i = t
199+ )
200+
201+ # no prompts were given or prompts were invalid, skip segmentation
202+ if seg is None :
203+ print ("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped." )
204+ return
186205
187206 # clear the old segmentation for this track_id
188207 old_mask = v .layers ["current_track" ].data [t ] == CURRENT_TRACK_ID
@@ -209,14 +228,16 @@ def track_objet_widget(
209228 with progress (total = shape [0 ]) as progress_bar :
210229 # step 1: segment all slices with prompts
211230 seg , slices , _ , stop_upper = segment_slices_with_prompts (
212- PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape ,
231+ PREDICTOR , v .layers ["prompts" ], v . layers [ "box_prompts" ], IMAGE_EMBEDDINGS , shape ,
213232 progress_bar = progress_bar , track_id = CURRENT_TRACK_ID
214233 )
215234
216235 # step 2: track the object starting from the lowest annotated slice
217236 seg , has_division = _track_from_prompts (
218- v .layers ["prompts" ], seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
219- projection = projection_ , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
237+ v .layers ["prompts" ], v .layers ["box_prompts" ], seg ,
238+ PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper ,
239+ threshold = iou_threshold , projection = projection_ ,
240+ progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
220241 )
221242
222243 # if a division has occurred and it's the first time it occurred for this track
@@ -231,7 +252,7 @@ def track_objet_widget(
231252 v .layers ["current_track" ].refresh ()
232253
233254
234- def create_tracking_menu (points_layer , states , track_ids ):
255+ def create_tracking_menu (points_layer , box_layer , states , track_ids ):
235256 state_menu = ComboBox (label = "track_state" , choices = states )
236257 track_id_menu = ComboBox (label = "track_id" , choices = list (map (str , track_ids )))
237258 tracking_widget = Container (widgets = [state_menu , track_id_menu ])
@@ -245,11 +266,25 @@ def update_track_id(event):
245266 global CURRENT_TRACK_ID
246267 new_id = str (points_layer .current_properties ["track_id" ][0 ])
247268 if new_id != track_id_menu .value :
248- state_menu .value = new_id
269+ track_id_menu .value = new_id
270+ CURRENT_TRACK_ID = int (new_id )
271+
272+ # def update_state_boxes(event):
273+ # new_state = str(box_layer.current_properties["state"][0])
274+ # if new_state != state_menu.value:
275+ # state_menu.value = new_state
276+
277+ def update_track_id_boxes (event ):
278+ global CURRENT_TRACK_ID
279+ new_id = str (box_layer .current_properties ["track_id" ][0 ])
280+ if new_id != track_id_menu .value :
281+ track_id_menu .value = new_id
249282 CURRENT_TRACK_ID = int (new_id )
250283
251284 points_layer .events .current_properties .connect (update_state )
252285 points_layer .events .current_properties .connect (update_track_id )
286+ # box_layer.events.current_properties.connect(update_state_boxes)
287+ box_layer .events .current_properties .connect (update_track_id_boxes )
253288
254289 def state_changed (new_state ):
255290 current_properties = points_layer .current_properties
@@ -264,8 +299,23 @@ def track_id_changed(new_track_id):
264299 points_layer .current_properties = current_properties
265300 CURRENT_TRACK_ID = int (new_track_id )
266301
302+ # def state_changed_boxes(new_state):
303+ # current_properties = box_layer.current_properties
304+ # current_properties["state"] = np.array([new_state])
305+ # box_layer.current_properties = current_properties
306+ # box_layer.refresh_colors()
307+
308+ def track_id_changed_boxes (new_track_id ):
309+ global CURRENT_TRACK_ID
310+ current_properties = box_layer .current_properties
311+ current_properties ["track_id" ] = np .array ([new_track_id ])
312+ box_layer .current_properties = current_properties
313+ CURRENT_TRACK_ID = int (new_track_id )
314+
267315 state_menu .changed .connect (state_changed )
268316 track_id_menu .changed .connect (track_id_changed )
317+ # state_menu.changed.connect(state_changed_boxes)
318+ track_id_menu .changed .connect (track_id_changed_boxes )
269319
270320 state_menu .set_choice ("track" )
271321 return tracking_widget
@@ -295,8 +345,7 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
295345 v .layers [layer ].data = np .zeros (shape , dtype = "uint32" )
296346 v .layers [layer ].refresh ()
297347
298- v .layers ["prompts" ].data = []
299- v .layers ["prompts" ].refresh ()
348+ clear_all_prompts (v )
300349
301350
302351def annotator_tracking (raw , embedding_path = None , show_embeddings = False , tracking_result = None , model_type = "vit_h" ):
@@ -333,7 +382,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
333382 # add the widgets
334383 #
335384 labels = ["positive" , "negative" ]
336- state_labels = ["division " , "track " ]
385+ state_labels = ["track " , "division " ]
337386 prompts = v .add_points (
338387 data = [[0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ]], # FIXME workaround
339388 name = "prompts" ,
@@ -354,6 +403,24 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
354403 prompts .edge_color_mode = "cycle"
355404 prompts .face_color_mode = "cycle"
356405
406+ # using the box layer to set divisions currently doesn't work
407+ # (and setting new track ids also doesn't work, but keeping track of them in the properties is working)
408+ box_prompts = v .add_shapes (
409+ data = [
410+ np .array ([[0 , 0 , 0 ], [0 , 0 , 10 ], [0 , 10 , 0 ], [0 , 10 , 10 ]]),
411+ np .array ([[0 , 0 , 0 ], [0 , 0 , 11 ], [0 , 11 , 0 ], [0 , 11 , 11 ]]),
412+ ], # FIXME workaround
413+ shape_type = "rectangle" , # FIXME workaround
414+ edge_width = 4 , ndim = 3 ,
415+ face_color = "transparent" ,
416+ name = "box_prompts" ,
417+ edge_color = "green" ,
418+ properties = {"track_id" : ["1" , "1" ]},
419+ # properties={"track_id": ["1", "1"], "state": state_labels},
420+ # edge_color_cycle=STATE_COLOR_CYCLE,
421+ )
422+ # box_prompts.edge_color_mode = "cycle"
423+
357424 #
358425 # add the widgets
359426 #
@@ -363,7 +430,7 @@ def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking
363430 prompt_widget = create_prompt_menu (prompts , labels )
364431 v .window .add_dock_widget (prompt_widget )
365432
366- TRACKING_WIDGET = create_tracking_menu (prompts , state_labels , list (LINEAGE .keys ()))
433+ TRACKING_WIDGET = create_tracking_menu (prompts , box_prompts , state_labels , list (LINEAGE .keys ()))
367434 v .window .add_dock_widget (TRACKING_WIDGET )
368435
369436 v .window .add_dock_widget (segment_frame_wigdet )
@@ -392,8 +459,7 @@ def _commit(v):
392459
393460 @v .bind_key ("Shift-C" )
394461 def clear_prompts (v ):
395- prompts .data = []
396- prompts .refresh ()
462+ clear_all_prompts (v )
397463
398464 #
399465 # start the viewer
0 commit comments