66from tqdm import tqdm
77
88from metaseg import SamAutomaticMaskGenerator , SamPredictor , sam_model_registry
9- from metaseg .utils import download_model , load_image , load_video
9+ from metaseg .utils import download_model , load_box , load_image , load_mask , load_video , multi_boxes
1010
1111
1212class SegAutoMaskPredictor :
1313 def __init__ (self ):
1414 self .model = None
1515 self .device = "cuda" if torch .cuda .is_available () else "cpu"
16+ self .save = False
17+ self .show = False
1618
1719 def load_model (self , model_type ):
1820 if self .model is None :
@@ -22,24 +24,17 @@ def load_model(self, model_type):
2224
2325 return self .model
2426
25- def predict (self , frame , model_type , points_per_side , points_per_batch , min_area ):
27+ def image_predict (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.png" ):
28+ read_image = load_image (source )
2629 model = self .load_model (model_type )
2730 mask_generator = SamAutomaticMaskGenerator (
2831 model , points_per_side = points_per_side , points_per_batch = points_per_batch , min_mask_region_area = min_area
2932 )
3033
31- masks = mask_generator .generate (frame )
32-
33- return frame , masks
34+ masks = mask_generator .generate (read_image )
3435
35- def save_image (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.png" ):
36- read_image = load_image (source )
37- image , anns = self .predict (read_image , model_type , points_per_side , points_per_batch , min_area )
38- if len (anns ) == 0 :
39- return
40-
41- sorted_anns = sorted (anns , key = (lambda x : x ["area" ]), reverse = True )
42- mask_image = np .zeros ((anns [0 ]["segmentation" ].shape [0 ], anns [0 ]["segmentation" ].shape [1 ], 3 ), dtype = np .uint8 )
36+ sorted_anns = sorted (masks , key = (lambda x : x ["area" ]), reverse = True )
37+ mask_image = np .zeros ((masks [0 ]["segmentation" ].shape [0 ], masks [0 ]["segmentation" ].shape [1 ], 3 ), dtype = np .uint8 )
4338 colors = np .random .randint (0 , 255 , size = (256 , 3 ), dtype = np .uint8 )
4439 for i , ann in enumerate (sorted_anns ):
4540 m = ann ["segmentation" ]
@@ -53,12 +48,18 @@ def save_image(self, source, model_type, points_per_side, points_per_batch, min_
5348 img = cv2 .addWeighted (img , 0.35 , np .zeros_like (img ), 0.65 , 0 )
5449 mask_image = cv2 .add (mask_image , img )
5550
56- combined_mask = cv2 .add (image , mask_image )
57- cv2 .imwrite (output_path , combined_mask )
51+ combined_mask = cv2 .add (read_image , mask_image )
52+ if self .save :
53+ cv2 .imwrite (output_path , combined_mask )
54+
55+ if self .show :
56+ cv2 .imshow ("Output" , combined_mask )
57+ cv2 .waitKey (0 )
58+ cv2 .destroyAllWindows ()
5859
5960 return output_path
6061
61- def save_video (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.mp4" ):
62+ def video_predict (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.mp4" ):
6263 cap , out = load_video (source , output_path )
6364 length = int (cap .get (cv2 .CAP_PROP_FRAME_COUNT ))
6465 colors = np .random .randint (0 , 255 , size = (256 , 3 ), dtype = np .uint8 )
@@ -68,18 +69,23 @@ def save_video(self, source, model_type, points_per_side, points_per_batch, min_
6869 if not ret :
6970 break
7071
71- image , anns = self .predict (frame , model_type , points_per_side , points_per_batch , min_area )
72- if len (anns ) == 0 :
72+ model = self .load_model (model_type )
73+ mask_generator = SamAutomaticMaskGenerator (
74+ model , points_per_side = points_per_side , points_per_batch = points_per_batch , min_mask_region_area = min_area
75+ )
76+ masks = mask_generator .generate (frame )
77+
78+ if len (masks ) == 0 :
7379 continue
7480
75- sorted_anns = sorted (anns , key = (lambda x : x ["area" ]), reverse = True )
81+ sorted_anns = sorted (masks , key = (lambda x : x ["area" ]), reverse = True )
7682 mask_image = np .zeros (
77- (anns [0 ]["segmentation" ].shape [0 ], anns [0 ]["segmentation" ].shape [1 ], 3 ), dtype = np .uint8
83+ (masks [0 ]["segmentation" ].shape [0 ], masks [0 ]["segmentation" ].shape [1 ], 3 ), dtype = np .uint8
7884 )
7985
8086 for i , ann in enumerate (sorted_anns ):
8187 m = ann ["segmentation" ]
82- color = colors [i % 256 ] # Her nesne için farklı bir renk kullan
88+ color = colors [i % 256 ]
8389 img = np .zeros ((m .shape [0 ], m .shape [1 ], 3 ), dtype = np .uint8 )
8490 img [:, :, 0 ] = color [0 ]
8591 img [:, :, 1 ] = color [1 ]
@@ -102,6 +108,8 @@ class SegManualMaskPredictor:
102108 def __init__ (self ):
103109 self .model = None
104110 self .device = "cuda" if torch .cuda .is_available () else "cpu"
111+ self .save = False
112+ self .show = False
105113
106114 def load_model (self , model_type ):
107115 if self .model is None :
@@ -111,49 +119,35 @@ def load_model(self, model_type):
111119
112120 return self .model
113121
114- def load_mask (self , mask , random_color ):
115- if random_color :
116- color = np .random .rand (3 ) * 255
117- else :
118- color = np .array ([100 , 50 , 0 ])
119-
120- h , w = mask .shape [- 2 :]
121- mask_image = mask .reshape (h , w , 1 ) * color .reshape (1 , 1 , - 1 )
122- mask_image = mask_image .astype (np .uint8 )
123- return mask_image
124-
125- def load_box (self , box , image ):
126- x , y , w , h = int (box [0 ]), int (box [1 ]), int (box [2 ]), int (box [3 ])
127- cv2 .rectangle (image , (x , y ), (w , h ), (0 , 255 , 0 ), 2 )
128- return image
129-
130- def multi_boxes (self , boxes , predictor , image ):
131- input_boxes = torch .tensor (boxes , device = predictor .device )
132- transformed_boxes = predictor .transform .apply_boxes_torch (input_boxes , image .shape [:2 ])
133- return input_boxes , transformed_boxes
134-
135- def predict (
122+ def image_predict (
136123 self ,
137- frame ,
124+ source ,
138125 model_type ,
139126 input_box = None ,
140127 input_point = None ,
141128 input_label = None ,
142129 multimask_output = False ,
130+ output_path = "output.png" ,
143131 ):
132+ image = load_image (source )
144133 model = self .load_model (model_type )
145134 predictor = SamPredictor (model )
146- predictor .set_image (frame )
135+ predictor .set_image (image )
147136
148137 if type (input_box [0 ]) == list :
149- input_boxes , new_boxes = self . multi_boxes (input_box , predictor , frame )
138+ input_boxes , new_boxes = multi_boxes (input_box , predictor , image )
150139
151140 masks , _ , _ = predictor .predict_torch (
152141 point_coords = None ,
153142 point_labels = None ,
154143 boxes = new_boxes ,
155144 multimask_output = False ,
156145 )
146+ for mask in masks :
147+ mask_image = load_mask (mask .cpu ().numpy (), False )
148+
149+ for box in input_boxes :
150+ image = load_box (box .cpu ().numpy (), image )
157151
158152 elif type (input_box [0 ]) == int :
159153 input_boxes = np .array (input_box )[None , :]
@@ -164,36 +158,16 @@ def predict(
164158 box = input_boxes ,
165159 multimask_output = multimask_output ,
166160 )
167-
168- return frame , masks , input_boxes
169-
170- def save_image (
171- self ,
172- source ,
173- model_type ,
174- input_box = None ,
175- input_point = None ,
176- input_label = None ,
177- multimask_output = False ,
178- output_path = "output.png" ,
179- ):
180- read_image = load_image (source )
181- image , anns , boxes = self .predict (read_image , model_type , input_box , input_point , input_label , multimask_output )
182- if len (anns ) == 0 :
183- return
184-
185- if type (input_box [0 ]) == list :
186- for mask in anns :
187- mask_image = self .load_mask (mask .cpu ().numpy (), False )
188-
189- for box in boxes :
190- image = self .load_box (box .cpu ().numpy (), image )
191-
192- elif type (input_box [0 ]) == int :
193- mask_image = self .load_mask (anns , True )
194- image = self .load_box (input_box , image )
161+ mask_image = load_mask (masks , True )
162+ image = load_box (input_box , image )
195163
196164 combined_mask = cv2 .add (image , mask_image )
197- cv2 .imwrite (output_path , combined_mask )
165+ if self .save :
166+ cv2 .imwrite (output_path , combined_mask )
167+
168+ if self .show :
169+ cv2 .imshow ("Output" , combined_mask )
170+ cv2 .waitKey (0 )
171+ cv2 .destroyAllWindows ()
198172
199173 return output_path
0 commit comments