22Tools for interactive segmentation
33"""
44
5- import sys
6-
5+ import cv2
76import matplotlib .pyplot as plt
87import numpy as np
98import torch
109from matplotlib import pyplot as plt
1110
1211
13- class ImageRenderer :
14- def __init__ (self , wait_for_button_press = True ):
15- """
16- Create a very light-weight image renderer.
17-
18- Args:
19- wait_for_button_press (bool): If True, each call to this renderer will pause the process until the user presses any key.
20- event_handler: Code to run given an event / button press. If None the default is mapping 'escape' and 'q' to sys.exit(0)
21- """
22- self ._image = None
23- self .last_event = None
24- self .wait_for_button_press = wait_for_button_press
25- self .pressed_keys = set ()
26-
27- def key_press_handler (self , event ):
28- self .last_event = event
29- self .pressed_keys .add (event .key )
30- if event .key in ["q" , "escape" ]:
31- sys .exit (0 )
32-
33- def key_release_handler (self , event ):
34- if event .key in self .pressed_keys :
35- self .pressed_keys .remove (event .key )
36-
37- def __call__ (self , buffer ):
38- if not self ._image :
39- plt .ion ()
40- self .fig , self .ax = plt .subplots ()
41- self ._image = self .ax .imshow (buffer , animated = True )
42- self .fig .canvas .mpl_connect ("key_press_event" , self .key_press_handler )
43- self .fig .canvas .mpl_connect ("key_release_event" , self .key_release_handler )
44- else :
45- self ._image .set_data (buffer )
46- if self .wait_for_button_press :
47- plt .waitforbuttonpress ()
48- else :
49- self .fig .canvas .draw_idle ()
50- self .fig .canvas .flush_events ()
51- plt .draw ()
52-
53- def __del__ (self ):
54- self .close ()
55-
56- def close (self ):
57- plt .ioff ()
58- plt .close ()
59-
60-
6112class InteractiveSegmentation :
6213 """
6314 Interactive segmentation tool. Opens a window from which you can click to record pixel positions.
@@ -94,68 +45,67 @@ def get_segmentation(self, images: np.ndarray):
9445
9546 There are a few other options that let the user e.g. redo the segmentation, redo the points etc., see the terminal output for help
9647 """
97- renderer = ImageRenderer (wait_for_button_press = False )
9848 state = "annotation"
9949 current_image_idx = 0
10050 masks = []
101- annotation_objs = []
10251 clicked_points = []
10352
104- def print_help_message ():
105- if state == "annotation" :
106- print (
107- f"Currently annotating image { current_image_idx + 1 } /{ len (images )} . Click to add a point of what to segment, right click to add a negative point of what not to segment. Press 't' when done. Press 'r' to clear the current point annotation and redo the points"
108- )
109- elif state == "segmentation" :
110- print (
111- f"Currently showing the predicted segmentation for image { current_image_idx + 1 } /{ len (images )} . Press 't' to move on to the next image. Press 'e' to delete this segmentation and edit the existing annotation points. Press 'r' to delete this segmentation and re-annotate the points for this image."
112- )
113-
114- def onclick (event ):
115- nonlocal annotation_objs , clicked_points
116- if event .xdata is not None and event .ydata is not None :
117- x , y = int (event .xdata ), int (event .ydata )
118- if event .button == 3 :
119- clicked_points .append ((x , y , 0 ))
120- annotation_objs .append (plt .plot (x , y , "ro" )[0 ])
121- else :
122- if x < 0 or x >= image .shape [1 ] or y < 0 or y >= image .shape [0 ]:
123- return
124- clicked_points .append ((x , y , 1 ))
125- annotation_objs .append (plt .plot (x , y , "go" )[0 ])
53+ state = "annotation"
12654
127- def clear_drawn_points ():
128- nonlocal annotation_objs
129- for x in annotation_objs :
130- x .remove ()
131- annotation_objs = []
55+ def print_help_message ():
56+ print (
57+ f"Currently annotating image { current_image_idx + 1 } /{ len (images )} . Click to add a point of what to segment, right click to add a negative point of what not to segment. Press 't' to generate a candidate segmentation mask. Press 'r' to clear the current point annotation. Press 'e' to edit the existing annotation points."
58+ )
13259
133- renderer (images [0 ])
134- renderer .ax .axis ("off" )
135- cid = None
136- print (
137- f"Starting annotation process for { len (images )} images. Press 't' to finish annotation, 'r' to redo annotation. Press 'h' for help."
60+ def mouse_callback (event , x , y , flags , param ):
61+ nonlocal clicked_points
62+ if event == cv2 .EVENT_LBUTTONDOWN :
63+ clicked_points .append ((x , y , 1 ))
64+ elif event == cv2 .EVENT_RBUTTONDOWN :
65+ clicked_points .append ((x , y , - 1 ))
66+
67+ # Display the image and set mouse callback
68+ annotation_window_name = "Annotation: Click for positive points, right click for negative points. 'r' to reset, 'e' to edit, 't' to generate the segmentation"
69+ check_window_name = (
70+ "Check segmentation quality. Press 't' to proceed. Press 'e' to edit again."
13871 )
139- print ("--------------------------------" )
72+ cv2 .namedWindow (annotation_window_name , cv2 .WINDOW_GUI_NORMAL )
73+ cv2 .setMouseCallback (annotation_window_name , mouse_callback )
74+
14075 print_help_message ()
76+
77+ point_size = int (0.01 * (images [0 ].shape [0 ] + images [0 ].shape [1 ]) / 2 )
14178 while current_image_idx < len (images ):
142- image = images [current_image_idx ].copy ()
143- key = renderer .last_event .key if renderer .last_event is not None else None
144- if renderer .last_event is not None :
145- renderer .last_event = None
146- if key == "q" :
147- renderer .close ()
148- return None
149- if key == "h" :
150- print_help_message ()
79+ display_img = images [current_image_idx ].copy ()
80+ image = display_img .copy ()
81+ key = cv2 .waitKey (1 )
15182 if state == "annotation" :
152- cid = renderer .fig .canvas .mpl_connect ("button_press_event" , onclick )
153- renderer .ax .set_title (
154- "Click on the image to record annotation points for segmentation"
155- )
156- renderer (image )
83+ if clicked_points :
84+ for x , y , label in clicked_points :
85+ cv2 .circle (
86+ display_img ,
87+ (x , y ),
88+ point_size ,
89+ (25 , 200 , 25 ) if label == 1 else (200 , 25 , 25 ),
90+ - 1 ,
91+ )
92+ if key == ord ("r" ):
93+ print ("(r)esetting the point annotations" )
94+ clicked_points = []
95+ elif key == ord ("e" ):
96+ print ("Entering (e)dit mode" )
97+ elif key == ord ("t" ):
98+ if len (clicked_points ) == 0 :
99+ print (
100+ "No points to generate the segmentation mask. Make sure to add at least one point."
101+ )
102+ continue
103+ print (
104+ "Generating the segmentation mask, check its quality. If the mask is good press 't' again to move on."
105+ )
106+ cv2 .setWindowTitle (annotation_window_name , check_window_name )
107+ state = "check"
157108
158- if key == "t" :
159109 if self .segmentation_model == "sam2" :
160110 clicked_points_np = np .array (clicked_points )
161111 input_label = clicked_points_np [:, 2 ]
@@ -169,46 +119,31 @@ def clear_drawn_points():
169119 )
170120 mask = mask [0 ]
171121 state = "segmentation"
172- clear_drawn_points ()
173- elif key == "r" :
174- clear_drawn_points ()
175- clicked_points = []
176- print ("Cleared previous points" )
177122 elif state == "segmentation" :
178- renderer .fig .canvas .mpl_disconnect (cid )
179- masked_image = image .copy ()
180123 mask_color = np .array ([30 , 144 , 255 ])
181124 mask_overlay = mask .astype (float ).reshape (
182125 image .shape [0 ], image .shape [1 ], 1
183126 ) * mask_color .reshape (1 , 1 , - 1 )
184- masked_image = mask_overlay * 0.6 + masked_image * 0.4
185- masked_image [mask == 0 ] = image [mask == 0 ]
186- renderer .ax .set_title ("Check the segmentation quality" )
187- renderer (masked_image .astype (np .uint8 ))
188- if key == "t" :
127+ display_img = mask_overlay * 0.6 + display_img * 0.4
128+ display_img [mask == 0 ] = image [mask == 0 ]
129+ display_img = display_img .astype (np .uint8 )
130+ if key == ord ("t" ):
189131 masks .append (mask )
190132 current_image_idx += 1
191133 state = "annotation"
192- clear_drawn_points ()
193134 clicked_points = []
194135 if current_image_idx < len (images ):
195136 print_help_message ()
196- elif key == "e" :
197- state = "annotation"
198- # redraw existing points since they got removed to show the segmentation image
199- for x in annotation_objs :
200- x .remove ()
201- annotation_objs = []
202- for pos in clicked_points :
203- annotation_objs .append (
204- renderer .ax .plot (
205- pos [0 ], pos [1 ], "ro" if pos [2 ] == 0 else "go"
206- )[0 ]
207- )
208- elif key == "r" :
137+ elif key == ord ("e" ):
138+ print ("Entering (e)dit mode" )
139+ cv2 .setWindowTitle (annotation_window_name , annotation_window_name )
209140 state = "annotation"
141+ elif key == ord ("r" ):
142+ print ("(r)esetting the point annotations" )
210143 clicked_points = []
211- clear_drawn_points ()
212- print ("Cleared previous points" )
213- renderer .close ()
144+ state = "annotation"
145+ cv2 .imshow (
146+ annotation_window_name , cv2 .cvtColor (display_img , cv2 .COLOR_RGB2BGR )
147+ )
148+ cv2 .destroyWindow (annotation_window_name )
214149 return np .stack (masks )
0 commit comments