Skip to content

Commit e2eaa08

Browse files
committed
Less broken annotator
1 parent 136f528 commit e2eaa08

File tree

2 files changed

+66
-130
lines changed

2 files changed

+66
-130
lines changed

easyhec/segmentation/interactive.py

Lines changed: 64 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -2,62 +2,13 @@
22
Tools for interactive segmentation
33
"""
44

5-
import sys
6-
5+
import cv2
76
import matplotlib.pyplot as plt
87
import numpy as np
98
import torch
109
from matplotlib import pyplot as plt
1110

1211

13-
class ImageRenderer:
14-
def __init__(self, wait_for_button_press=True):
15-
"""
16-
Create a very light-weight image renderer.
17-
18-
Args:
19-
wait_for_button_press (bool): If True, each call to this renderer will pause the process until the user presses any key.
20-
event_handler: Code to run given an event / button press. If None the default is mapping 'escape' and 'q' to sys.exit(0)
21-
"""
22-
self._image = None
23-
self.last_event = None
24-
self.wait_for_button_press = wait_for_button_press
25-
self.pressed_keys = set()
26-
27-
def key_press_handler(self, event):
28-
self.last_event = event
29-
self.pressed_keys.add(event.key)
30-
if event.key in ["q", "escape"]:
31-
sys.exit(0)
32-
33-
def key_release_handler(self, event):
34-
if event.key in self.pressed_keys:
35-
self.pressed_keys.remove(event.key)
36-
37-
def __call__(self, buffer):
38-
if not self._image:
39-
plt.ion()
40-
self.fig, self.ax = plt.subplots()
41-
self._image = self.ax.imshow(buffer, animated=True)
42-
self.fig.canvas.mpl_connect("key_press_event", self.key_press_handler)
43-
self.fig.canvas.mpl_connect("key_release_event", self.key_release_handler)
44-
else:
45-
self._image.set_data(buffer)
46-
if self.wait_for_button_press:
47-
plt.waitforbuttonpress()
48-
else:
49-
self.fig.canvas.draw_idle()
50-
self.fig.canvas.flush_events()
51-
plt.draw()
52-
53-
def __del__(self):
54-
self.close()
55-
56-
def close(self):
57-
plt.ioff()
58-
plt.close()
59-
60-
6112
class InteractiveSegmentation:
6213
"""
6314
Interactive segmentation tool. Opens a window from which you can click to record pixel positions.
@@ -94,68 +45,67 @@ def get_segmentation(self, images: np.ndarray):
9445
9546
There are a few other options that let the user e.g. redo the segmentation, redo the points etc., see the terminal output for help
9647
"""
97-
renderer = ImageRenderer(wait_for_button_press=False)
9848
state = "annotation"
9949
current_image_idx = 0
10050
masks = []
101-
annotation_objs = []
10251
clicked_points = []
10352

104-
def print_help_message():
105-
if state == "annotation":
106-
print(
107-
f"Currently annotating image {current_image_idx+1}/{len(images)}. Click to add a point of what to segment, right click to add a negative point of what not to segment. Press 't' when done. Press 'r' to clear the current point annotation and redo the points"
108-
)
109-
elif state == "segmentation":
110-
print(
111-
f"Currently showing the predicted segmentation for image {current_image_idx+1}/{len(images)}. Press 't' to move on to the next image. Press 'e' to delete this segmentation and edit the existing annotation points. Press 'r' to delete this segmentation and re-annotate the points for this image."
112-
)
113-
114-
def onclick(event):
115-
nonlocal annotation_objs, clicked_points
116-
if event.xdata is not None and event.ydata is not None:
117-
x, y = int(event.xdata), int(event.ydata)
118-
if event.button == 3:
119-
clicked_points.append((x, y, 0))
120-
annotation_objs.append(plt.plot(x, y, "ro")[0])
121-
else:
122-
if x < 0 or x >= image.shape[1] or y < 0 or y >= image.shape[0]:
123-
return
124-
clicked_points.append((x, y, 1))
125-
annotation_objs.append(plt.plot(x, y, "go")[0])
53+
state = "annotation"
12654

127-
def clear_drawn_points():
128-
nonlocal annotation_objs
129-
for x in annotation_objs:
130-
x.remove()
131-
annotation_objs = []
55+
def print_help_message():
56+
print(
57+
f"Currently annotating image {current_image_idx+1}/{len(images)}. Click to add a point of what to segment, right click to add a negative point of what not to segment. Press 't' to generate a candidate segmentation mask. Press 'r' to clear the current point annotation. Press 'e' to edit the existing annotation points."
58+
)
13259

133-
renderer(images[0])
134-
renderer.ax.axis("off")
135-
cid = None
136-
print(
137-
f"Starting annotation process for {len(images)} images. Press 't' to finish annotation, 'r' to redo annotation. Press 'h' for help."
60+
def mouse_callback(event, x, y, flags, param):
61+
nonlocal clicked_points
62+
if event == cv2.EVENT_LBUTTONDOWN:
63+
clicked_points.append((x, y, 1))
64+
elif event == cv2.EVENT_RBUTTONDOWN:
65+
clicked_points.append((x, y, -1))
66+
67+
# Display the image and set mouse callback
68+
annotation_window_name = "Annotation: Click for positive points, right click for negative points. 'r' to reset, 'e' to edit, 't' to generate the segmentation"
69+
check_window_name = (
70+
"Check segmentation quality. Press 't' to proceed. Press 'e' to edit again."
13871
)
139-
print("--------------------------------")
72+
cv2.namedWindow(annotation_window_name, cv2.WINDOW_GUI_NORMAL)
73+
cv2.setMouseCallback(annotation_window_name, mouse_callback)
74+
14075
print_help_message()
76+
77+
point_size = int(0.01 * (images[0].shape[0] + images[0].shape[1]) / 2)
14178
while current_image_idx < len(images):
142-
image = images[current_image_idx].copy()
143-
key = renderer.last_event.key if renderer.last_event is not None else None
144-
if renderer.last_event is not None:
145-
renderer.last_event = None
146-
if key == "q":
147-
renderer.close()
148-
return None
149-
if key == "h":
150-
print_help_message()
79+
display_img = images[current_image_idx].copy()
80+
image = display_img.copy()
81+
key = cv2.waitKey(1)
15182
if state == "annotation":
152-
cid = renderer.fig.canvas.mpl_connect("button_press_event", onclick)
153-
renderer.ax.set_title(
154-
"Click on the image to record annotation points for segmentation"
155-
)
156-
renderer(image)
83+
if clicked_points:
84+
for x, y, label in clicked_points:
85+
cv2.circle(
86+
display_img,
87+
(x, y),
88+
point_size,
89+
(25, 200, 25) if label == 1 else (200, 25, 25),
90+
-1,
91+
)
92+
if key == ord("r"):
93+
print("(r)esetting the point annotations")
94+
clicked_points = []
95+
elif key == ord("e"):
96+
print("Entering (e)dit mode")
97+
elif key == ord("t"):
98+
if len(clicked_points) == 0:
99+
print(
100+
"No points to generate the segmentation mask. Make sure to add at least one point."
101+
)
102+
continue
103+
print(
104+
"Generating the segmentation mask, check its quality. If the mask is good press 't' again to move on."
105+
)
106+
cv2.setWindowTitle(annotation_window_name, check_window_name)
107+
state = "check"
157108

158-
if key == "t":
159109
if self.segmentation_model == "sam2":
160110
clicked_points_np = np.array(clicked_points)
161111
input_label = clicked_points_np[:, 2]
@@ -169,46 +119,31 @@ def clear_drawn_points():
169119
)
170120
mask = mask[0]
171121
state = "segmentation"
172-
clear_drawn_points()
173-
elif key == "r":
174-
clear_drawn_points()
175-
clicked_points = []
176-
print("Cleared previous points")
177122
elif state == "segmentation":
178-
renderer.fig.canvas.mpl_disconnect(cid)
179-
masked_image = image.copy()
180123
mask_color = np.array([30, 144, 255])
181124
mask_overlay = mask.astype(float).reshape(
182125
image.shape[0], image.shape[1], 1
183126
) * mask_color.reshape(1, 1, -1)
184-
masked_image = mask_overlay * 0.6 + masked_image * 0.4
185-
masked_image[mask == 0] = image[mask == 0]
186-
renderer.ax.set_title("Check the segmentation quality")
187-
renderer(masked_image.astype(np.uint8))
188-
if key == "t":
127+
display_img = mask_overlay * 0.6 + display_img * 0.4
128+
display_img[mask == 0] = image[mask == 0]
129+
display_img = display_img.astype(np.uint8)
130+
if key == ord("t"):
189131
masks.append(mask)
190132
current_image_idx += 1
191133
state = "annotation"
192-
clear_drawn_points()
193134
clicked_points = []
194135
if current_image_idx < len(images):
195136
print_help_message()
196-
elif key == "e":
197-
state = "annotation"
198-
# redraw existing points since they got removed to show the segmentation image
199-
for x in annotation_objs:
200-
x.remove()
201-
annotation_objs = []
202-
for pos in clicked_points:
203-
annotation_objs.append(
204-
renderer.ax.plot(
205-
pos[0], pos[1], "ro" if pos[2] == 0 else "go"
206-
)[0]
207-
)
208-
elif key == "r":
137+
elif key == ord("e"):
138+
print("Entering (e)dit mode")
139+
cv2.setWindowTitle(annotation_window_name, annotation_window_name)
209140
state = "annotation"
141+
elif key == ord("r"):
142+
print("(r)esetting the point annotations")
210143
clicked_points = []
211-
clear_drawn_points()
212-
print("Cleared previous points")
213-
renderer.close()
144+
state = "annotation"
145+
cv2.imshow(
146+
annotation_window_name, cv2.cvtColor(display_img, cv2.COLOR_RGB2BGR)
147+
)
148+
cv2.destroyWindow(annotation_window_name)
214149
return np.stack(masks)

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="easyhec",
5-
version="0.1.7",
5+
version="0.1.8",
66
packages=find_packages(),
77
package_data={"easyhec": ["examples/real/robot_definitions/**"]},
88
author="Stone Tao",
@@ -20,6 +20,7 @@
2020
"transforms3d",
2121
"matplotlib",
2222
"urchin",
23+
"opencv-python",
2324
# ninja is used by nvdiffrast
2425
"ninja>=1.11",
2526
],

0 commit comments

Comments
 (0)