- # ComfyUI-RMBG
+ # ComfyUI-RMBG V2.9.1
  # This custom node for ComfyUI provides functionality for background removal using various models,
  # including RMBG-2.0, INSPYRENET, BEN, BEN2 and BIREFNET-HR. It leverages deep learning techniques
  # to process images and generate masks for background removal.
@@ -572,49 +572,51 @@ def process_image(self, image, model, **params):
                 handle_model_error(download_message)
             print("Model files downloaded successfully")
 
-        for img in image:
-            mask = model_instance.process_image(img, model, params)
-
+        model_type = AVAILABLE_MODELS[model]["type"]
+
+        def _process_pair(img, mask):
             if isinstance(mask, list):
                 masks = [m.convert("L") for m in mask if isinstance(m, Image.Image)]
-                mask = masks[0] if masks else None
+                mask_local = masks[0] if masks else None
             elif isinstance(mask, Image.Image):
-                mask = mask.convert("L")
-
-            mask_tensor = pil2tensor(mask)
-            mask_tensor = mask_tensor * (1 + (1 - params["sensitivity"]))
-            mask_tensor = torch.clamp(mask_tensor, 0, 1)
-            mask = tensor2pil(mask_tensor)
+                mask_local = mask.convert("L")
+            else:
+                mask_local = mask
+
+            mask_tensor_local = pil2tensor(mask_local)
+            mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"]))
+            mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1)
+            mask_img_local = tensor2pil(mask_tensor_local)
 
             if params["mask_blur"] > 0:
-                mask = mask.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
+                mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
 
             if params["mask_offset"] != 0:
                 if params["mask_offset"] > 0:
                     for _ in range(params["mask_offset"]):
-                        mask = mask.filter(ImageFilter.MaxFilter(3))
+                        mask_img_local = mask_img_local.filter(ImageFilter.MaxFilter(3))
                 else:
                     for _ in range(-params["mask_offset"]):
-                        mask = mask.filter(ImageFilter.MinFilter(3))
+                        mask_img_local = mask_img_local.filter(ImageFilter.MinFilter(3))
 
             if params["invert_output"]:
-                mask = Image.fromarray(255 - np.array(mask))
-
-            img_tensor = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
-            mask_tensor = torch.from_numpy(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0
-
-            orig_image = tensor2pil(img)
+                mask_img_local = Image.fromarray(255 - np.array(mask_img_local))
+
+            img_tensor_local = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
+            mask_tensor_b1hw = torch.from_numpy(np.array(mask_img_local)).unsqueeze(0).unsqueeze(0) / 255.0
+
+            orig_image_local = tensor2pil(img)
 
             if params.get("refine_foreground", False):
-                refined_fg = refine_foreground(img_tensor, mask_tensor)
-                refined_fg = tensor2pil(refined_fg[0].permute(1, 2, 0))
-                r, g, b = refined_fg.split()
-                foreground = Image.merge('RGBA', (r, g, b, mask))
+                refined_fg_local = refine_foreground(img_tensor_local, mask_tensor_b1hw)
+                refined_fg_local = tensor2pil(refined_fg_local[0].permute(1, 2, 0))
+                r, g, b = refined_fg_local.split()
+                foreground_local = Image.merge('RGBA', (r, g, b, mask_img_local))
             else:
-                orig_rgba = orig_image.convert("RGBA")
-                r, g, b, _ = orig_rgba.split()
-                foreground = Image.merge('RGBA', (r, g, b, mask))
-
+                orig_rgba_local = orig_image_local.convert("RGBA")
+                r, g, b, _ = orig_rgba_local.split()
+                foreground_local = Image.merge('RGBA', (r, g, b, mask_img_local))
+
             if params["background"] == "Color":
                 def hex_to_rgba(hex_color):
                     hex_color = hex_color.lstrip('#')
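The hunk above folds the old per-image loop body into a local `_process_pair` helper so the same cleanup can serve both the new batched path and the existing one-image path. As a reading aid, here is a minimal standalone sketch of the mask cleanup that helper applies (sensitivity scaling, optional blur, grow/shrink by one pixel per step, optional invert); it uses plain Pillow/NumPy instead of the node's `pil2tensor`/`tensor2pil` helpers, and the `refine_mask` name is purely illustrative:

```python
import numpy as np
from PIL import Image, ImageFilter


def refine_mask(mask: Image.Image, sensitivity: float = 1.0, mask_blur: int = 0,
                mask_offset: int = 0, invert_output: bool = False) -> Image.Image:
    """Sketch of the per-image mask cleanup chain (not the node's exact helpers)."""
    mask = mask.convert("L")

    # Sensitivity: scale mask values by (2 - sensitivity) and clamp, so a lower
    # sensitivity pushes soft mask values toward full opacity.
    arr = np.asarray(mask, dtype=np.float32) / 255.0
    arr = np.clip(arr * (1 + (1 - sensitivity)), 0.0, 1.0)
    mask = Image.fromarray((arr * 255).astype(np.uint8))

    # Optional feathering of the mask edge.
    if mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    # Positive offset grows the mask, negative offset shrinks it, one pixel per step.
    if mask_offset > 0:
        for _ in range(mask_offset):
            mask = mask.filter(ImageFilter.MaxFilter(3))
    elif mask_offset < 0:
        for _ in range(-mask_offset):
            mask = mask.filter(ImageFilter.MinFilter(3))

    if invert_output:
        mask = Image.fromarray(255 - np.array(mask))

    return mask
```

In the node itself the cleaned mask is then attached as the alpha channel of the source image via `Image.merge('RGBA', (r, g, b, mask))`, optionally after `refine_foreground` has re-estimated the RGB foreground first.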
@@ -628,14 +630,29 @@ def hex_to_rgba(hex_color):
                     return (r, g, b, a)
                 background_color = params.get("background_color", "#222222")
                 rgba = hex_to_rgba(background_color)
-                bg_image = Image.new('RGBA', orig_image.size, rgba)
-                composite_image = Image.alpha_composite(bg_image, foreground)
+                bg_image = Image.new('RGBA', orig_image_local.size, rgba)
+                composite_image = Image.alpha_composite(bg_image, foreground_local)
                 processed_images.append(pil2tensor(composite_image.convert("RGB")))
             else:
-                processed_images.append(pil2tensor(foreground))
+                processed_images.append(pil2tensor(foreground_local))
 
-            processed_masks.append(pil2tensor(mask))
-
+            processed_masks.append(pil2tensor(mask_img_local))
+
+        if model_type in ("rmbg", "ben2"):
+            images_list = [img for img in image]
+            chunk_size = 4
+            for start in range(0, len(images_list), chunk_size):
+                batch_imgs = images_list[start:start + chunk_size]
+                masks = model_instance.process_image(batch_imgs, model, params)
+                if isinstance(masks, Image.Image):
+                    masks = [masks]
+                for img_item, mask_item in zip(batch_imgs, masks):
+                    _process_pair(img_item, mask_item)
+        else:
+            for img in image:
+                mask = model_instance.process_image(img, model, params)
+                _process_pair(img, mask)
+
         mask_images = []
         for mask_tensor in processed_masks:
             mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
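The second hunk then chooses between two execution paths: models whose backend accepts a list of images at once (the "rmbg" and "ben2" types) are processed in chunks of four, while every other model keeps the old one-image-at-a-time loop. A generic sketch of that chunking pattern follows; `run_in_chunks` and `batched_fn` are illustrative names, not anything defined in the node:

```python
from typing import Callable, List, Sequence


def run_in_chunks(items: Sequence, batched_fn: Callable[[list], object], chunk_size: int = 4) -> List:
    """Feed a batch-capable backend a few items at a time and flatten the results."""
    results: List = []
    for start in range(0, len(items), chunk_size):
        chunk = list(items[start:start + chunk_size])
        out = batched_fn(chunk)        # assumed to return one result per input item
        if not isinstance(out, list):  # single-image chunks may come back unwrapped
            out = [out]
        results.extend(out)
    return results


# Roughly mirrors the batched branch in the diff above:
# masks = run_in_chunks(list(image),
#                       lambda batch: model_instance.process_image(batch, model, params))
```

Pairing each chunk with its results via `zip(batch_imgs, masks)` keeps every mask aligned with its source image, and a small fixed `chunk_size` caps peak VRAM while still amortizing per-call overhead across a few frames.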