|
7 | 7 | from .. import util |
8 | 8 | from .. import segment_instances |
9 | 9 | from ..visualization import project_embeddings_for_visualization |
10 | | -from ..segment_from_prompts import segment_from_points |
| 10 | +from ..segment_from_prompts import segment_from_box, segment_from_box_and_points, segment_from_points |
11 | 11 | from .util import ( |
12 | 12 | commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE |
13 | 13 | ) |
14 | 14 |
|
15 | 15 |
|
16 | 16 | @magicgui(call_button="Segment Object [S]") |
17 | 17 | def segment_wigdet(v: Viewer): |
| 18 | + # get the current point prompts |
18 | 19 | points, labels = prompt_layer_to_points(v.layers["prompts"]) |
19 | | - seg = segment_from_points(PREDICTOR, points, labels) |
20 | | - v.layers["current_object"].data = seg.squeeze() |
| 20 | + assert len(points) == len(labels) |
| 21 | + have_points = len(points) > 0 |
| 22 | + |
| 23 | + # get the current box prompts |
| 24 | + box_layer = v.layers["box_prompts"] |
| 25 | + have_boxes = box_layer.nshapes > 0 |
| 26 | + |
| 27 | + # segment only with points |
| 28 | + if have_points and not have_boxes: |
| 29 | + seg = segment_from_points(PREDICTOR, points, labels).squeeze() |
| 30 | + |
| 31 | + # segment only with boxes |
| 32 | + elif not have_points and have_boxes: |
| 33 | + shape = v.layers["current_object"].data.shape |
| 34 | + seg = np.zeros(shape, dtype="uint32") |
| 35 | + |
| 36 | + seg_id = 1 |
| 37 | + for prompt_id in range(box_layer.nshapes): |
| 38 | + shape_type = box_layer.shape_type[prompt_id] |
| 39 | + |
| 40 | + # for now we only support segmentation from rectangles. |
| 41 | + # supporting other shapes would be possible by casting the shape to a mask |
| 42 | + # and then segmenting from mask and bounding box. |
| 43 | + # but for this we need to fix issue with resizing the mask for non-square shapes. |
| 44 | + if shape_type != "rectangle": |
| 45 | + print(f"You have provided a {shape_type} shape.") |
| 46 | + print("We currently only support rectangle shapes for prompts and this prompt will be skipped.") |
| 47 | + continue |
| 48 | + |
| 49 | + box = box_layer.data[prompt_id] |
| 50 | + prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) |
| 51 | + mask = segment_from_box(PREDICTOR, prompt_box).squeeze() |
| 52 | + seg[mask] = seg_id |
| 53 | + seg_id += 1 |
| 54 | + |
| 55 | + # segment with points and box (currently only one box supported) |
| 56 | + elif have_points and have_boxes: |
| 57 | + if box_layer.nshapes > 1: |
| 58 | + print("You have provided point prompts and more than one box prompt.") |
| 59 | + print("This setting is currently not supported.") |
| 60 | + print("When providing both points and prompts you can only segment one object at a time.") |
| 61 | + return |
| 62 | + |
| 63 | + box = box_layer.data[0] |
| 64 | + prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) |
| 65 | + seg = segment_from_box_and_points(PREDICTOR, prompt_box, points, labels).squeeze() |
| 66 | + |
| 67 | + # no prompts were given, skip segmentation |
| 68 | + else: |
| 69 | + print("You haven't given any prompts.") |
| 70 | + print("Please provide point and/or box prompts.") |
| 71 | + return |
| 72 | + |
| 73 | + v.layers["current_object"].data = seg |
21 | 74 | v.layers["current_object"].refresh() |
22 | 75 |
|
23 | 76 |
|
@@ -85,6 +138,10 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r |
85 | 138 | ) |
86 | 139 | prompts.edge_color_mode = "cycle" |
87 | 140 |
|
| 141 | + box_prompts = v.add_shapes( |
| 142 | + face_color="transparent", edge_color="green", edge_width=4, name="box_prompts" |
| 143 | + ) |
| 144 | + |
88 | 145 | # |
89 | 146 | # add the widgets |
90 | 147 | # |
@@ -118,6 +175,8 @@ def _toggle_label(event=None): |
118 | 175 | def clear_prompts(v): |
119 | 176 | prompts.data = [] |
120 | 177 | prompts.refresh() |
| 178 | + box_prompts.data = [] |
| 179 | + box_prompts.refresh() |
121 | 180 |
|
122 | 181 | # |
123 | 182 | # start the viewer |
|
0 commit comments