@@ -97,11 +97,14 @@ def apply_background_color(image: Image.Image, mask_image: Image.Image,
                            background_color: str = "#222222") -> Image.Image:
     rgba_image = image.copy().convert('RGBA')
     rgba_image.putalpha(mask_image.convert('L'))
+
     if background == "Color":
         def hex_to_rgba(hex_color):
             hex_color = hex_color.lstrip('#')
             r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
             return (r, g, b, 255)
+        params = {"background_color": background_color}
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)
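
For reference, the compositing path in this hunk can be exercised on its own. The sketch below is a minimal, self-contained version of the same idea (the helper name composite_over_color is illustrative and not part of this repository): the mask's L channel becomes the foreground alpha, and a solid background of the requested hex color is composited underneath.

from PIL import Image

def composite_over_color(image: Image.Image, mask: Image.Image,
                         hex_color: str = "#222222") -> Image.Image:
    # Parse "#RRGGBB" into an opaque RGBA tuple.
    hex_color = hex_color.lstrip('#')
    rgba = tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4)) + (255,)
    # The mask drives the foreground alpha; composite it over the solid color.
    fg = image.convert('RGBA')
    fg.putalpha(mask.convert('L'))
    bg = Image.new('RGBA', image.size, rgba)
    return Image.alpha_composite(bg, fg)
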
@@ -146,7 +149,7 @@ def INPUT_TYPES(cls):
                 "dino_model": (list(DINO_MODELS.keys()),),
             },
             "optional": {
-                "threshold": ("FLOAT", {"default": 0.30, "min": 0.05, "max": 0.95, "step": 0.01, "tooltip": tooltips["threshold"]}),
+                "threshold": ("FLOAT", {"default": 0.35, "min": 0.05, "max": 0.95, "step": 0.01, "tooltip": tooltips["threshold"]}),
                 "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
                 "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
                 "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
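
The only functional change in this hunk is the default detection threshold, raised from 0.30 to 0.35. That value feeds the box filter later in this commit (logits.max(dim=1)[0] > threshold), so a higher default simply discards more low-confidence GroundingDINO proposals. A small sketch with synthetic scores (not real model output) illustrates the effect:

import torch

# Synthetic per-box token scores; real values come from pred_logits.sigmoid().
logits = torch.tensor([[0.10, 0.42],
                       [0.05, 0.33],
                       [0.20, 0.28]])
boxes = torch.rand(3, 4)  # cx, cy, w, h in [0, 1]

for threshold in (0.30, 0.35):
    keep = logits.max(dim=1)[0] > threshold
    print(f"threshold={threshold}: {int(keep.sum())} boxes kept")
# threshold=0.30 keeps two of the three proposals, threshold=0.35 keeps only one.
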
@@ -167,105 +170,150 @@ def __init__(self):
     def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                    mask_blur=0, mask_offset=0, background="Alpha",
                    background_color="#222222", invert_output=False):
-        img_pil = tensor2pil(image[0]) if image.ndim == 4 else tensor2pil(image)
-        img_np = np.array(img_pil.convert("RGB"))
         device = "cuda" if torch.cuda.is_available() else "cpu"

-        # Load GroundingDINO config and weights
-        dino_info = DINO_MODELS[dino_model]
-        config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
-        weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")
+        # Process batched images
+        batch_size = image.shape[0] if len(image.shape) == 4 else 1
+        if len(image.shape) == 3:
+            image = image.unsqueeze(0)
+
+        result_images = []
+        result_masks = []
+        result_mask_images = []
+
+        for b in range(batch_size):
+            img_pil = tensor2pil(image[b])
+            img_np = np.array(img_pil.convert("RGB"))

-        # Load and cache GroundingDINO model
-        dino_key = (config_path, weights_path, device)
-        if dino_key not in self.dino_model_cache:
-            args = SLConfig.fromfile(config_path)
-            model = build_model(args)
-            checkpoint = torch.load(weights_path, map_location="cpu")
-            model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
-            model.eval()
-            model.to(device)
-            self.dino_model_cache[dino_key] = model
-        dino = self.dino_model_cache[dino_key]
+            # Load GroundingDINO config and weights
+            dino_info = DINO_MODELS[dino_model]
+            config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
+            weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")

-        # Preprocess image for DINO
-        from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
-        transform = Compose([
-            RandomResize([800], max_size=1333),
-            ToTensor(),
-            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ])
-        image_tensor, _ = transform(img_pil.convert("RGB"), None)
-        image_tensor = image_tensor.unsqueeze(0).to(device)
+            # Load and cache GroundingDINO model
+            dino_key = (config_path, weights_path, device)
+            if dino_key not in self.dino_model_cache:
+                args = SLConfig.fromfile(config_path)
+                model = build_model(args)
+                checkpoint = torch.load(weights_path, map_location="cpu")
+                model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+                model.eval()
+                model.to(device)
+                self.dino_model_cache[dino_key] = model
+            dino = self.dino_model_cache[dino_key]

-        # Prepare text prompt
-        text_prompt = prompt if prompt.endswith(".") else prompt + "."
+            # Download/check SAM weights
+            sam_info = SAM_MODELS[sam_model]
+            sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")

-        # Forward pass
-        with torch.no_grad():
-            outputs = dino(image_tensor, captions=[text_prompt])
-            logits = outputs["pred_logits"].sigmoid()[0]
-            boxes = outputs["pred_boxes"][0]
+            # Load SAM model (cache to avoid reloading)
+            sam_key = (sam_info["model_type"], sam_ckpt_path, device)
+            if sam_key not in self.sam_model_cache:
+                try:
+                    sam = sam_model_registry[sam_info["model_type"]](checkpoint=sam_ckpt_path)
+                    sam.to(device)
+                    self.sam_model_cache[sam_key] = SamPredictor(sam)
+                except RuntimeError as e:
+                    if "Unexpected key(s) in state_dict" in str(e):
+                        print("Warning: SAM model loading issue detected, please try using the SegmentV1 node instead")
+                        print(f"Error details: {str(e)}")
+                        width, height = img_pil.size
+                        empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
+                        empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+                        result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
+                        result_images.append(pil2tensor(result_image))
+                        result_masks.append(empty_mask)
+                        result_mask_images.append(empty_mask_rgb)
+                        continue
+                    else:
+                        raise e
+            predictor = self.sam_model_cache[sam_key]

-        # Filter boxes by threshold
-        filt_mask = logits.max(dim=1)[0] > threshold
-        boxes_filt = boxes[filt_mask]
-        if boxes_filt.shape[0] == 0:
-            width, height = img_pil.size
-            empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
-            empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
-            result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
-            return (pil2tensor(result_image), empty_mask, empty_mask_rgb)
+            # Preprocess image for DINO
+            from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
+            transform = Compose([
+                RandomResize([800], max_size=1333),
+                ToTensor(),
+                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ])
+            image_tensor, _ = transform(img_pil.convert("RGB"), None)
+            image_tensor = image_tensor.unsqueeze(0).to(device)
+
+            # Prepare text prompt
+            text_prompt = prompt if prompt.endswith(".") else prompt + "."
+
+            # Forward pass
+            with torch.no_grad():
+                outputs = dino(image_tensor, captions=[text_prompt])
+                logits = outputs["pred_logits"].sigmoid()[0]
+                boxes = outputs["pred_boxes"][0]

-        # Convert boxes to xyxy
-        H, W = img_pil.size[1], img_pil.size[0]
-        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
-        boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
-        boxes_xyxy = boxes_xyxy.cpu().numpy()
+            # Filter boxes by threshold
+            filt_mask = logits.max(dim=1)[0] > threshold
+            boxes_filt = boxes[filt_mask]
+
+            # Handle case with no detected boxes
+            if boxes_filt.shape[0] == 0:
+                width, height = img_pil.size
+                empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
+                empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+                result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
+                result_images.append(pil2tensor(result_image))
+                result_masks.append(empty_mask)
+                result_mask_images.append(empty_mask_rgb)
+                continue

-        # Download/check SAM weights
-        sam_info = SAM_MODELS[sam_model]
-        sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")
+            # Convert boxes to xyxy
+            H, W = img_pil.size[1], img_pil.size[0]
+            boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
+            boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
+            boxes_xyxy = boxes_xyxy.cpu().numpy()

-        # Load SAM model (cache to avoid reloading)
-        sam_key = (sam_info["model_type"], sam_ckpt_path, device)
-        if sam_key not in self.sam_model_cache:
-            sam = sam_model_registry[sam_info["model_type"]](checkpoint=sam_ckpt_path)
-            sam.to(device)
-            self.sam_model_cache[sam_key] = SamPredictor(sam)
-        predictor = self.sam_model_cache[sam_key]
+            # Use SAM to get masks for each box
+            predictor.set_image(img_np)
+            boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
+            transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
+            masks, _, _ = predictor.predict_torch(
+                point_coords=None,
+                point_labels=None,
+                boxes=transformed_boxes,
+                multimask_output=False
+            )

-        # Use SAM to get masks for each box
-        predictor.set_image(img_np)
-        boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
-        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
-        masks, _, _ = predictor.predict_torch(
-            point_coords=None,
-            point_labels=None,
-            boxes=transformed_boxes,
-            multimask_output=False
-        )
-        # Process mask following the original implementation
-        print(f"Mask shape before processing: {masks.shape}")
-        # Combine all masks into one
-        combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
-        mask = combined_mask.float().cpu().numpy()
-        print(f"Mask shape after processing: {mask.shape}")
-        # Squeeze out the extra dimension to get a 2D array
-        mask = mask.squeeze(0)
-        print(f"Final mask shape: {mask.shape}")
-        mask = (mask * 255).astype(np.uint8)
-        mask_pil = Image.fromarray(mask, mode="L")
+            # Combine all masks into one
+            combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
+            mask = combined_mask.float().cpu().numpy()
+            mask = mask.squeeze(0)
+            mask = (mask * 255).astype(np.uint8)
+            mask_pil = Image.fromarray(mask, mode="L")

-        mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
-        result_image = apply_background_color(img_pil, mask_image, background, background_color)
-        if background == "Color":
-            result_image = result_image.convert("RGB")
-        else:
-            result_image = result_image.convert("RGBA")
-        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
-        mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
-        return (pil2tensor(result_image), mask_tensor, mask_image_vis)
+            # Process mask and apply background
+            mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
+            result_image = apply_background_color(img_pil, mask_image, background, background_color)
+            if background == "Color":
+                result_image = result_image.convert("RGB")
+            else:
+                result_image = result_image.convert("RGBA")
+
+            # Convert to tensors
+            mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
+            mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
+
+            result_images.append(pil2tensor(result_image))
+            result_masks.append(mask_tensor)
+            result_mask_images.append(mask_image_vis)
+
+        # If no image was processed successfully, return empty results
+        if len(result_images) == 0:
+            width, height = tensor2pil(image[0]).size
+            empty_mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32, device="cpu")
+            empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+            return (image, empty_mask, empty_mask_rgb)
+
+        # Concatenate results from all batch items
+        return (torch.cat(result_images, dim=0),
+                torch.cat(result_masks, dim=0),
+                torch.cat(result_mask_images, dim=0))

 NODE_CLASS_MAPPINGS = {
     "SegmentV2": SegmentV2,
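
The bulk of the final hunk re-indents the former single-image pipeline into a per-item loop and concatenates the per-image results along the batch dimension, which is what lets the node accept batched IMAGE inputs. Below is a minimal sketch of that accumulate-then-concatenate pattern, with a dummy process_one standing in for the DINO + SAM steps; the [B, H, W, C] image / [B, H, W] mask layout is assumed here, following common ComfyUI conventions.

import torch

def process_one(img):
    # Stand-in for the per-image pipeline: returns an image tensor with a
    # leading batch dim of 1 and a matching 1 x H x W mask.
    h, w, c = img.shape
    return img.unsqueeze(0), torch.zeros((1, h, w))

batch = torch.rand(4, 64, 64, 3)  # pretend 4-image batch
result_images, result_masks = [], []
for b in range(batch.shape[0]):
    img_out, mask_out = process_one(batch[b])
    result_images.append(img_out)
    result_masks.append(mask_out)

out_images = torch.cat(result_images, dim=0)  # (4, 64, 64, 3)
out_masks = torch.cat(result_masks, dim=0)    # (4, 64, 64)

Note that torch.cat along dim 0 requires every per-item tensor to share the same spatial size, which holds within a single ComfyUI batch.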