66from tqdm import tqdm
77
88from metaseg import SamAutomaticMaskGenerator , SamPredictor , sam_model_registry
9- from metaseg .utils import download_model , load_box , load_image , load_mask , load_video , multi_boxes
9+ from metaseg .utils import download_model , load_box , load_image , load_mask , load_video , multi_boxes , show_image , save_image
1010
1111
1212class SegAutoMaskPredictor :
1313 def __init__ (self ):
1414 self .model = None
1515 self .device = "cuda" if torch .cuda .is_available () else "cpu"
16- self .save = False
17- self .show = False
1816
1917 def load_model (self , model_type ):
2018 if self .model is None :
@@ -24,7 +22,7 @@ def load_model(self, model_type):
2422
2523 return self .model
2624
27- def image_predict (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.png" ):
25+ def image_predict (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.png" , show = False , save = False ):
2826 read_image = load_image (source )
2927 model = self .load_model (model_type )
3028 mask_generator = SamAutomaticMaskGenerator (
@@ -49,15 +47,15 @@ def image_predict(self, source, model_type, points_per_side, points_per_batch, m
4947 mask_image = cv2 .add (mask_image , img )
5048
5149 combined_mask = cv2 .add (read_image , mask_image )
52- if self .save :
53- cv2 . imwrite ( output_path , combined_mask )
54-
55- if self . show :
56- cv2 . imshow ( "Output" , combined_mask )
57- cv2 . waitKey ( 0 )
58- cv2 . destroyAllWindows ()
59-
60- return output_path
50+ self .combined_mask = combined_mask
51+ if show :
52+ show_image ( combined_mask )
53+
54+ if save :
55+ save_image ( output_path = output_path , image = combined_mask )
56+
57+ return masks
58+
6159
6260 def video_predict (self , source , model_type , points_per_side , points_per_batch , min_area , output_path = "output.mp4" ):
6361 cap , out = load_video (source , output_path )
@@ -128,6 +126,9 @@ def image_predict(
128126 input_label = None ,
129127 multimask_output = False ,
130128 output_path = "output.png" ,
129+ random_color = False ,
130+ show = False ,
131+ save = False ,
131132 ):
132133 image = load_image (source )
133134 model = self .load_model (model_type )
@@ -144,7 +145,7 @@ def image_predict(
144145 multimask_output = False ,
145146 )
146147 for mask in masks :
147- mask_image = load_mask (mask .cpu ().numpy (), False )
148+ mask_image = load_mask (mask .cpu ().numpy (), random_color )
148149
149150 for box in input_boxes :
150151 image = load_box (box .cpu ().numpy (), image )
@@ -158,16 +159,14 @@ def image_predict(
158159 box = input_boxes ,
159160 multimask_output = multimask_output ,
160161 )
161- mask_image = load_mask (masks , True )
162+ mask_image = load_mask (masks , random_color )
162163 image = load_box (input_box , image )
163164
164165 combined_mask = cv2 .add (image , mask_image )
165- if self . save :
166- cv2 . imwrite (output_path , combined_mask )
166+ if save :
167+ save_image (output_path = output_path , image = combined_mask )
167168
168- if self .show :
169- cv2 .imshow ("Output" , combined_mask )
170- cv2 .waitKey (0 )
171- cv2 .destroyAllWindows ()
169+ if show :
170+ show_image (combined_mask )
172171
173- return output_path
172+ return masks
0 commit comments