44from magicgui .widgets import ComboBox , Container
55from napari import Viewer
66
7- from ..segment_from_prompts import segment_from_points
7+ from ..segment_from_prompts import segment_from_box , segment_from_box_and_points , segment_from_points
88
99# Green and Red
1010LABEL_COLOR_CYCLE = ["#00FF00" , "#FF0000" ]
1111
1212
13+ def clear_all_prompts (v ):
14+ v .layers ["prompts" ].data = []
15+ v .layers ["prompts" ].refresh ()
16+ if "box_prompts" in v .layers :
17+ v .layers ["box_prompts" ].data = []
18+ v .layers ["box_prompts" ].refresh ()
19+
20+
1321@magicgui (call_button = "Commit [C]" , layer = {"choices" : ["current_object" , "auto_segmentation" ]})
1422def commit_segmentation_widget (v : Viewer , layer : str = "current_object" ):
1523 seg = v .layers [layer ].data
@@ -25,11 +33,7 @@ def commit_segmentation_widget(v: Viewer, layer: str = "current_object"):
2533 v .layers [layer ].refresh ()
2634
2735 if layer == "current_object" :
28- v .layers ["prompts" ].data = []
29- v .layers ["prompts" ].refresh ()
30- if "box_prompts" in v .layers :
31- v .layers ["box_prompts" ].data = []
32- v .layers ["box_prompts" ].refresh ()
36+ clear_all_prompts (v )
3337
3438
3539def create_prompt_menu (points_layer , labels , menu_name = "prompt" , label_name = "label" ):
@@ -59,7 +63,8 @@ def prompt_layer_to_points(prompt_layer, i=None, track_id=None):
5963
6064 Arguments:
6165 prompt_layer: the point layer
62- i [int] - index for the data (required for 3d data)
66+ i [int] - index for the data (required for 3d or timeseries data)
67+ track_id [int] - id of the current track (required for tracking data)
6368 """
6469
6570 points = prompt_layer .data
@@ -92,6 +97,50 @@ def prompt_layer_to_points(prompt_layer, i=None, track_id=None):
9297 return this_points , this_labels
9398
9499
100+ def prompt_layer_to_boxes (prompt_layer , i = None , track_id = None ):
101+ """Extract box prompts for SAM from shape layer.
102+
103+ Arguments:
104+ prompt_layer: the point layer
105+ i [int] - index for the data (required for 3d or timeseries data)
106+ track_id [int] - id of the current track (required for tracking data)
107+ """
108+ shape_data = prompt_layer .data
109+ shape_types = prompt_layer .shape_type
110+ assert len (shape_data ) == len (shape_types )
111+
112+ if i is None :
113+ # select all boxes that are rectangles
114+ boxes = [data for data , stype in zip (shape_data , shape_types ) if stype == "rectangle" ]
115+ else :
116+ # we are currently only supporting rectangle shapes.
117+ # other shapes could be supported by providing them as rough mask
118+ # (and also providing the corresponding bounding box)
119+ # but for this we need to figure out the mask prompts for non-square shapes
120+ non_rectangle = [stype != "rectangle" for stype in shape_types ]
121+ if any (non_rectangle ):
122+ print (f"You have provided { sum (non_rectangle )} shapes that are not rectangles." )
123+ print ("We currently do not support these as prompts and they will be ignored." )
124+ boxes = [
125+ data [:, 1 :] for data , stype in zip (shape_data , shape_types )
126+ if (stype == "rectangle" and (data [:, 0 ] == i ).all ())
127+ ]
128+
129+ # TODO support for track_id
130+ # if track_id is not None:
131+ # assert i is not None
132+ # track_ids = np.array(list(map(int, prompt_layer.properties["track_id"])))[mask]
133+ # track_id_mask = track_ids == track_id
134+ # this_labels, this_points = this_labels[track_id_mask], this_points[track_id_mask]
135+ # assert len(this_points) == len(this_labels)
136+
137+ # map to correct box format
138+ boxes = [
139+ np .array ([box [:, 0 ].min (), box [:, 1 ].min (), box [:, 0 ].max (), box [:, 1 ].max ()]) for box in boxes
140+ ]
141+ return boxes
142+
143+
95144def prompt_layer_to_state (prompt_layer , i ):
96145 """Get the state of the track from the prompt layer.
97146 Only relevant for annotator_tracking.
@@ -116,27 +165,36 @@ def prompt_layer_to_state(prompt_layer, i):
116165 return "track"
117166
118167
119- def segment_slices_with_prompts (predictor , prompt_layer , image_embeddings , shape , progress_bar = None , track_id = None ):
168+ def segment_slices_with_prompts (
169+ predictor , point_prompts , box_prompts , image_embeddings , shape , progress_bar = None , track_id = None
170+ ):
171+ """
172+ """
173+ assert len (shape ) == 3
174+ image_shape = shape [1 :]
120175 seg = np .zeros (shape , dtype = "uint32" )
121176
122- z_values = prompt_layer .data [:, 0 ]
177+ z_values = point_prompts .data [:, 0 ]
178+ z_values_boxes = np .concatenate ([box [:, 0 ] for box in box_prompts .data ])
179+
180+ # TODO add track id properties to boxes as well, filter z_values_boxes accordingly
123181 if track_id is not None :
124- track_ids = np .array (list (map (int , prompt_layer .properties ["track_id" ])))
182+ track_ids = np .array (list (map (int , point_prompts .properties ["track_id" ])))
125183 assert len (track_ids ) == len (z_values )
126184 z_values = z_values [track_ids == track_id ]
127185
128- slices = np .unique (z_values ).astype ("int" )
186+ slices = np .unique (np . concatenate ([ z_values , z_values_boxes ]) ).astype ("int" )
129187 stop_lower , stop_upper = False , False
130188
131189 def _update_progress ():
132190 if progress_bar is not None :
133191 progress_bar .update (1 )
134192
135193 for i in slices :
136- prompts_i = prompt_layer_to_points (prompt_layer , i , track_id )
194+ points_i = prompt_layer_to_points (point_prompts , i , track_id )
137195
138196 # do we end the segmentation at the outer slices?
139- if prompts_i is None :
197+ if points_i is None :
140198
141199 if i == slices [0 ]:
142200 stop_lower = True
@@ -149,14 +207,71 @@ def _update_progress():
149207 _update_progress ()
150208 continue
151209
152- points , labels = prompts_i
153- seg_i = segment_from_points (predictor , points , labels , image_embeddings = image_embeddings , i = i )
210+ boxes = prompt_layer_to_boxes (box_prompts , i , track_id )
211+ points , labels = points_i
212+
213+ seg_i = prompt_segmentation (
214+ predictor , points , labels , boxes , image_shape , multiple_box_prompts = False ,
215+ image_embeddings = image_embeddings , i = i
216+ )
217+ if seg_i is None :
218+ print (f"The prompts at slice or frame { i } are invalid and the segmentation was skipped." )
219+ print ("This will lead to a wrong segmentation across slices or frames." )
220+ print (f"Please correct the prompts in { i } and rerun the segmentation." )
221+ continue
222+
154223 seg [i ] = seg_i
155224 _update_progress ()
156225
157226 return seg , slices , stop_lower , stop_upper
158227
159228
229+ def prompt_segmentation (
230+ predictor , points , labels , boxes , shape , multiple_box_prompts , image_embeddings = None , i = None
231+ ):
232+ """
233+ """
234+ assert len (points ) == len (labels )
235+ have_points = len (points ) > 0
236+ have_boxes = len (boxes ) > 0
237+
238+ # no prompts were given, return None
239+ if not have_points and not have_boxes :
240+ return
241+
242+ # box and ppint prompts were given
243+ elif have_points and have_boxes :
244+ if len (boxes ) > 1 :
245+ print ("You have provided point prompts and more than one box prompt." )
246+ print ("This setting is currently not supported." )
247+ print ("When providing both points and prompts you can only segment one object at a time." )
248+ return
249+ seg = segment_from_box_and_points (
250+ predictor , boxes [0 ], points , labels , image_embeddings = image_embeddings , i = i
251+ ).squeeze ()
252+
253+ # only point prompts were given
254+ elif have_points and not have_boxes :
255+ seg = segment_from_points (predictor , points , labels , image_embeddings = image_embeddings , i = i ).squeeze ()
256+
257+ # only box prompts were given
258+ elif not have_points and have_boxes :
259+ seg = np .zeros (shape , dtype = "uint32" )
260+
261+ if len (boxes ) > 1 and not multiple_box_prompts :
262+ print ("You have provided more than one box annotation. This is not yet supported in the 3d annotator." )
263+ print ("You can only segment one object at a time in 3d." )
264+ return
265+
266+ seg_id = 1
267+ for box in boxes :
268+ mask = segment_from_box (predictor , box , image_embeddings = image_embeddings , i = i ).squeeze ()
269+ seg [mask ] = seg_id
270+ seg_id += 1
271+
272+ return seg
273+
274+
160275def toggle_label (prompts ):
161276 # get the currently selected label
162277 current_properties = prompts .current_properties
0 commit comments