 
 from AILab_ImageMaskTools import pil2tensor, tensor2pil
 
-# SAM model definitions (6 models)
 SAM_MODELS = {
     "sam_vit_h (2.56GB)": {
         "model_url": "https://huggingface.co/1038lab/sam/resolve/main/sam_vit_h.pth",
     }
 }
 
-# GroundingDINO model definitions (2 models)
 DINO_MODELS = {
     "GroundingDINO_SwinT_OGC (694MB)": {
         "config_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
@@ -97,7 +95,6 @@ def apply_background_color(image: Image.Image, mask_image: Image.Image,
                            background_color: str = "#222222") -> Image.Image:
     rgba_image = image.copy().convert('RGBA')
     rgba_image.putalpha(mask_image.convert('L'))
-
     if background == "Color":
         def hex_to_rgba(hex_color):
             hex_color = hex_color.lstrip('#')
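
Note: the hunk cuts hex_to_rgba off after the lstrip; a minimal completion of the helper and the compositing it supports (an assumption, not necessarily the author's exact code):

```python
from PIL import Image

def hex_to_rgba(hex_color: str) -> tuple:
    hex_color = hex_color.lstrip('#')
    r, g, b = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    return (r, g, b, 255)

def composite_on_color(rgba_image: Image.Image, hex_color: str) -> Image.Image:
    # Paste the masked RGBA cutout over a solid background color.
    background = Image.new("RGBA", rgba_image.size, hex_to_rgba(hex_color))
    return Image.alpha_composite(background, rgba_image)
```
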
@@ -172,25 +169,18 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                    background_color="#222222", invert_output=False):
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        # Process batched input images
         batch_size = image.shape[0] if len(image.shape) == 4 else 1
         if len(image.shape) == 3:
             image = image.unsqueeze(0)
-
         result_images = []
         result_masks = []
         result_mask_images = []
-
         for b in range(batch_size):
             img_pil = tensor2pil(image[b])
             img_np = np.array(img_pil.convert("RGB"))
-
-            # Load GroundingDINO config and weights
             dino_info = DINO_MODELS[dino_model]
             config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
             weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")
-
-            # Load and cache GroundingDINO model
             dino_key = (config_path, weights_path, device)
             if dino_key not in self.dino_model_cache:
                 args = SLConfig.fromfile(config_path)
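
Note: the lines between SLConfig.fromfile and model.to(device) fall outside this hunk. The usual GroundingDINO build-and-load pattern (assumed here, not confirmed by the diff) is:

```python
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict

def load_groundingdino(config_path: str, weights_path: str, device: str):
    # Build the model from its config, then load the released checkpoint.
    args = SLConfig.fromfile(config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model.to(device)
```
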
@@ -201,12 +191,8 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 model.to(device)
                 self.dino_model_cache[dino_key] = model
             dino = self.dino_model_cache[dino_key]
-
-            # Download/check SAM weights
             sam_info = SAM_MODELS[sam_model]
             sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")
-
-            # Load SAM model (cache to avoid reloading)
             sam_key = (sam_info["model_type"], sam_ckpt_path, device)
             if sam_key not in self.sam_model_cache:
                 try:
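
Note: the body of the try block is elided by the hunk; a typical load with the official segment_anything API (an assumption about what that block contains):

```python
from segment_anything import sam_model_registry, SamPredictor

def load_sam_predictor(model_type: str, ckpt_path: str, device: str) -> SamPredictor:
    # model_type is one of "vit_h", "vit_l", "vit_b" (the model_type values SAM_MODELS stores).
    sam = sam_model_registry[model_type](checkpoint=ckpt_path)
    sam.to(device)
    return SamPredictor(sam)
```
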
@@ -230,8 +216,6 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 else:
                     raise e
             predictor = self.sam_model_cache[sam_key]
-
-            # Preprocess image for DINO
             from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
             transform = Compose([
                 RandomResize([800], max_size=1333),
@@ -240,21 +224,13 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
             ])
             image_tensor, _ = transform(img_pil.convert("RGB"), None)
             image_tensor = image_tensor.unsqueeze(0).to(device)
-
-            # Prepare text prompt
             text_prompt = prompt if prompt.endswith(".") else prompt + "."
-
-            # Forward pass
             with torch.no_grad():
                 outputs = dino(image_tensor, captions=[text_prompt])
                 logits = outputs["pred_logits"].sigmoid()[0]
                 boxes = outputs["pred_boxes"][0]
-
-            # Filter boxes by threshold
             filt_mask = logits.max(dim=1)[0] > threshold
             boxes_filt = boxes[filt_mask]
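
Note: the filtering above works on per-query token scores: DINO emits one logit row per query over the text tokens, and a query counts as a detection when its best token score clears the threshold. A standalone sketch with dummy shapes:

```python
import torch

logits = torch.rand(900, 256)        # (num_queries, num_text_tokens), already sigmoided
boxes = torch.rand(900, 4)           # (num_queries, 4), normalized cxcywh
keep = logits.max(dim=1)[0] > 0.30   # best token score per query vs. threshold
boxes_filt = boxes[keep]             # (num_kept, 4)
```
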
-
-            # Handle case with no detected boxes
             if boxes_filt.shape[0] == 0:
                 width, height = img_pil.size
                 empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
@@ -264,14 +240,10 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 result_masks.append(empty_mask)
                 result_mask_images.append(empty_mask_rgb)
                 continue
-
-            # Convert boxes to xyxy
             H, W = img_pil.size[1], img_pil.size[0]
             boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
             boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
             boxes_xyxy = boxes_xyxy.cpu().numpy()
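
Note: box_ops.box_cxcywh_to_xyxy turns center-based boxes into corner form before they are scaled to pixel coordinates; the underlying math, written out for reference:

```python
import torch

def cxcywh_to_xyxy(b: torch.Tensor) -> torch.Tensor:
    # (cx, cy, w, h) -> (x0, y0, x1, y1), still in normalized coordinates.
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
```
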
-
-            # Use SAM to get masks for each box
             predictor.set_image(img_np)
             boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
             transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
@@ -281,38 +253,27 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 boxes=transformed_boxes,
                 multimask_output=False
             )
-
-            # Combine all masks into one
-            combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
+            combined_mask = torch.max(masks, dim=0)[0]
             mask = combined_mask.float().cpu().numpy()
             mask = mask.squeeze(0)
             mask = (mask * 255).astype(np.uint8)
             mask_pil = Image.fromarray(mask, mode="L")
-
-            # Process mask and apply background
             mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
             result_image = apply_background_color(img_pil, mask_image, background, background_color)
             if background == "Color":
                 result_image = result_image.convert("RGB")
             else:
                 result_image = result_image.convert("RGBA")
-
-            # Convert to tensors
             mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
             mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
-
             result_images.append(pil2tensor(result_image))
             result_masks.append(mask_tensor)
             result_mask_images.append(mask_image_vis)
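
Note: the reshape/movedim/expand chain above targets ComfyUI's tensor layouts (IMAGE is (B, H, W, C) float32 in [0, 1]; MASK is (B, H, W)). A standalone shape check of the same conversion:

```python
import numpy as np
import torch
from PIL import Image

mask_image = Image.new("L", (64, 48), 255)  # stand-in 64x48 mask
mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
print(mask_tensor.shape)  # torch.Size([1, 48, 64])  -> MASK (B, H, W)
vis = (mask_tensor
       .reshape((-1, 1, mask_image.height, mask_image.width))  # (B, 1, H, W)
       .movedim(1, -1)                                         # (B, H, W, 1)
       .expand(-1, -1, -1, 3))                                 # (B, H, W, 3) grayscale IMAGE
print(vis.shape)  # torch.Size([1, 48, 64, 3])
```
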
-
-        # If no image was processed successfully, return an empty result
         if len(result_images) == 0:
             width, height = tensor2pil(image[0]).size
             empty_mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32, device="cpu")
             empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
             return (image, empty_mask, empty_mask_rgb)
-
-        # Merge the results of all batches
         return (torch.cat(result_images, dim=0),
                 torch.cat(result_masks, dim=0),
                 torch.cat(result_mask_images, dim=0))
@@ -324,4 +285,3 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
 NODE_DISPLAY_NAME_MAPPINGS = {
     "SegmentV2": "Segmentation V2 (RMBG)",
 }
-
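
Note: ComfyUI also requires a NODE_CLASS_MAPPINGS dict with the same key; it sits outside this diff. A hypothetical sketch of the pairing (the node class name is assumed):

```python
class SegmentV2:  # placeholder standing in for the node class edited above
    pass

NODE_CLASS_MAPPINGS = {"SegmentV2": SegmentV2}
NODE_DISPLAY_NAME_MAPPINGS = {"SegmentV2": "Segmentation V2 (RMBG)"}
```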