Skip to content

Commit c91993d

Browse files
Implement box prompts for the 2d annotator
1 parent 45119dd commit c91993d

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,70 @@
77
from .. import util
88
from .. import segment_instances
99
from ..visualization import project_embeddings_for_visualization
10-
from ..segment_from_prompts import segment_from_points
10+
from ..segment_from_prompts import segment_from_box, segment_from_box_and_points, segment_from_points
1111
from .util import (
1212
commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
1313
)
1414

1515

1616
@magicgui(call_button="Segment Object [S]")
1717
def segment_wigdet(v: Viewer):
18+
# get the current point prompts
1819
points, labels = prompt_layer_to_points(v.layers["prompts"])
19-
seg = segment_from_points(PREDICTOR, points, labels)
20-
v.layers["current_object"].data = seg.squeeze()
20+
assert len(points) == len(labels)
21+
have_points = len(points) > 0
22+
23+
# get the current box prompts
24+
box_layer = v.layers["box_prompts"]
25+
have_boxes = box_layer.nshapes > 0
26+
27+
# segment only with points
28+
if have_points and not have_boxes:
29+
seg = segment_from_points(PREDICTOR, points, labels).squeeze()
30+
31+
# segment only with boxes
32+
elif not have_points and have_boxes:
33+
shape = v.layers["current_object"].data.shape
34+
seg = np.zeros(shape, dtype="uint32")
35+
36+
seg_id = 1
37+
for prompt_id in range(box_layer.nshapes):
38+
shape_type = box_layer.shape_type[prompt_id]
39+
40+
# for now we only support segmentation from rectangles.
41+
# supporting other shapes would be possible by casting the shape to a mask
42+
# and then segmenting from mask and bounding box.
43+
# but for this we need to fix issue with resizing the mask for non-square shapes.
44+
if shape_type != "rectangle":
45+
print(f"You have provided a {shape_type} shape.")
46+
print("We currently only support rectangle shapes for prompts and this prompt will be skipped.")
47+
continue
48+
49+
box = box_layer.data[prompt_id]
50+
prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()])
51+
mask = segment_from_box(PREDICTOR, prompt_box).squeeze()
52+
seg[mask] = seg_id
53+
seg_id += 1
54+
55+
# segment with points and box (currently only one box supported)
56+
elif have_points and have_boxes:
57+
if box_layer.nshapes > 1:
58+
print("You have provided point prompts and more than one box prompt.")
59+
print("This setting is currently not supported.")
60+
print("When providing both points and prompts you can only segment one object at a time.")
61+
return
62+
63+
box = box_layer.data[0]
64+
prompt_box = np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()])
65+
seg = segment_from_box_and_points(PREDICTOR, prompt_box, points, labels).squeeze()
66+
67+
# no prompts were given, skip segmentation
68+
else:
69+
print("You haven't given any prompts.")
70+
print("Please provide point and/or box prompts.")
71+
return
72+
73+
v.layers["current_object"].data = seg
2174
v.layers["current_object"].refresh()
2275

2376

@@ -85,6 +138,10 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
85138
)
86139
prompts.edge_color_mode = "cycle"
87140

141+
box_prompts = v.add_shapes(
142+
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts"
143+
)
144+
88145
#
89146
# add the widgets
90147
#
@@ -118,6 +175,8 @@ def _toggle_label(event=None):
118175
def clear_prompts(v):
119176
prompts.data = []
120177
prompts.refresh()
178+
box_prompts.data = []
179+
box_prompts.refresh()
121180

122181
#
123182
# start the viewer

micro_sam/sam_annotator/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def commit_segmentation_widget(v: Viewer, layer: str = "current_object"):
2727
if layer == "current_object":
2828
v.layers["prompts"].data = []
2929
v.layers["prompts"].refresh()
30+
if "box_prompts" in v.layers:
31+
v.layers["box_prompts"].data = []
32+
v.layers["box_prompts"].refresh()
3033

3134

3235
def create_prompt_menu(points_layer, labels, menu_name="prompt", label_name="label"):

0 commit comments

Comments
 (0)