1- # ComfyUI-RMBG v2.5 .0
1+ # ComfyUI-RMBG v2.6 .0
22#
33# This node facilitates background removal using various models, including RMBG-2.0, INSPYRENET, BEN, BEN2, and BIREFNET-HR.
44# It utilizes advanced deep learning techniques to process images and generate accurate masks for background removal.
3131#
3232# 5. Input Nodes:
3333# - ColorInput: A node for inputting colors in various formats.
34-
34+ #
35+ # License: GPL-3.0
3536# These nodes are crafted to streamline common image and mask operations within ComfyUI workflows.
3637
3738import os
@@ -587,6 +588,8 @@ def _resize_if_needed(self, mask, target_shape):
587588
588589# Image loader node
589590class AILab_LoadImage :
591+ upscale_methods = ["nearest-exact" , "bilinear" , "area" , "bicubic" , "lanczos" ]
592+
590593 @classmethod
591594 def INPUT_TYPES (cls ):
592595 input_dir = folder_paths .get_input_directory ()
@@ -596,6 +599,7 @@ def INPUT_TYPES(cls):
596599 "required" : {
597600 "image" : (sorted (files ) or ["" ], {"image_upload" : True }),
598601 "mask_channel" : (["alpha" , "red" , "green" , "blue" ], {"default" : "alpha" , "tooltip" : "Select channel to extract mask from" }),
602+ "upscale_method" : (cls .upscale_methods , {"default" : "lanczos" , "tooltip" : "Method used for resizing the image" }),
599603 "scale_by" : ("FLOAT" , {"default" : 1.0 , "min" : 0.01 , "max" : 8.0 , "step" : 0.01 , "tooltip" : "Scale image by this factor (ignored if size > 0)" }),
600604 "resize_mode" : (["longest_side" , "shortest_side" , "width" , "height" ], {"default" : "longest_side" , "tooltip" : "Choose how to resize the image" }),
601605 "size" : ("INT" , {"default" : 0 , "min" : 0 , "max" : MAX_RESOLUTION , "step" : 1 , "tooltip" : "Target size for the selected resize mode (0 = keep original size)" }),
@@ -611,14 +615,28 @@ def INPUT_TYPES(cls):
611615 FUNCTION = "load_image"
612616 OUTPUT_NODE = False
613617
614- def load_image (self , image , mask_channel = "alpha" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
618+ def load_image (self , image , mask_channel = "alpha" , upscale_method = "lanczos" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
615619 try :
616620 image_path = folder_paths .get_annotated_filepath (image )
617621 img = Image .open (image_path )
618622
619623 orig_width , orig_height = img .size
620624
621- # Image resizing logic
625+ resampling_map = {
626+ "nearest-exact" : Image .NEAREST ,
627+ "bilinear" : Image .BILINEAR ,
628+ "area" : Image .BOX ,
629+ "bicubic" : Image .BICUBIC ,
630+ "lanczos" : Image .LANCZOS
631+ }
632+ resampling = resampling_map .get (upscale_method , Image .LANCZOS )
633+
634+ has_alpha = 'A' in img .getbands ()
635+ if has_alpha and mask_channel == "alpha" :
636+ original_alpha = img .getchannel ('A' )
637+
638+ img_rgb = img .convert ('RGB' )
639+
622640 if size > 0 :
623641 if resize_mode == "longest_side" :
624642 if orig_width >= orig_height :
@@ -627,57 +645,66 @@ def load_image(self, image, mask_channel="alpha", scale_by=1.0, resize_mode="lon
627645 else :
628646 new_height = size
629647 new_width = int (orig_width * (size / orig_height ))
630- img = img .resize ((new_width , new_height ), Image . LANCZOS )
648+ img_rgb = img_rgb .resize ((new_width , new_height ), resampling )
631649 elif resize_mode == "shortest_side" :
632650 if orig_width <= orig_height :
633651 new_width = size
634652 new_height = int (orig_height * (size / orig_width ))
635653 else :
636654 new_height = size
637655 new_width = int (orig_width * (size / orig_height ))
638- img = img .resize ((new_width , new_height ), Image . LANCZOS )
656+ img_rgb = img_rgb .resize ((new_width , new_height ), resampling )
639657 elif resize_mode == "width" :
640658 new_width = size
641659 new_height = int (orig_height * (size / orig_width ))
642- img = img .resize ((new_width , new_height ), Image . LANCZOS )
660+ img_rgb = img_rgb .resize ((new_width , new_height ), resampling )
643661 elif resize_mode == "height" :
644662 new_height = size
645663 new_width = int (orig_width * (size / orig_height ))
646- img = img .resize ((new_width , new_height ), Image . LANCZOS )
664+ img_rgb = img_rgb .resize ((new_width , new_height ), resampling )
647665 elif scale_by != 1.0 :
648666 new_width = int (orig_width * scale_by )
649667 new_height = int (orig_height * scale_by )
650- img = img .resize ((new_width , new_height ), Image . LANCZOS )
668+ img_rgb = img_rgb .resize ((new_width , new_height ), resampling )
651669
652- width , height = img .size
670+ width , height = img_rgb .size
671+
672+ mask = None
673+ if mask_channel == "alpha" and has_alpha :
674+ if (size > 0 or scale_by != 1.0 ) and 'original_alpha' in locals ():
675+ mask_img = original_alpha .resize ((width , height ), resampling )
676+ mask = np .array (mask_img ).astype (np .float32 ) / 255.0
677+ mask = 1. - torch .from_numpy (mask )
653678
654679 output_images = []
655680 output_masks = []
656- for i in ImageSequence .Iterator (img ):
681+
682+ for i in ImageSequence .Iterator (img_rgb ):
657683 i = ImageOps .exif_transpose (i )
658684 if i .mode == 'I' :
659685 i = i .point (lambda i : i * (1 / 255 ))
660- image = i .convert ("RGB" )
661- image = np .array (image ).astype (np .float32 ) / 255.0
686+
687+ if i .mode != 'RGB' :
688+ i = i .convert ('RGB' )
689+
690+ image = np .array (i ).astype (np .float32 ) / 255.0
662691 image = torch .from_numpy (image )[None ,]
663692
664- if mask_channel == "alpha" and 'A' in i .getbands ():
665- mask = np .array (i .getchannel ('A' )).astype (np .float32 ) / 255.0
666- mask = 1. - torch .from_numpy (mask )
693+ if mask is not None :
694+ output_masks .append (mask .unsqueeze (0 ))
667695 elif mask_channel == "red" and 'R' in i .getbands ():
668696 mask = np .array (i .getchannel ('R' )).astype (np .float32 ) / 255.0
669- mask = torch .from_numpy (mask )
697+ output_masks . append ( torch .from_numpy (mask ). unsqueeze ( 0 ) )
670698 elif mask_channel == "green" and 'G' in i .getbands ():
671699 mask = np .array (i .getchannel ('G' )).astype (np .float32 ) / 255.0
672- mask = torch .from_numpy (mask )
700+ output_masks . append ( torch .from_numpy (mask ). unsqueeze ( 0 ) )
673701 elif mask_channel == "blue" and 'B' in i .getbands ():
674702 mask = np .array (i .getchannel ('B' )).astype (np .float32 ) / 255.0
675- mask = torch .from_numpy (mask )
703+ output_masks . append ( torch .from_numpy (mask ). unsqueeze ( 0 ) )
676704 else :
677- mask = torch .ones ((height , width ), dtype = torch .float32 , device = "cpu" )
705+ output_masks . append ( torch .ones ((1 , height , width ), dtype = torch .float32 , device = "cpu" ) )
678706
679707 output_images .append (image )
680- output_masks .append (mask .unsqueeze (0 ))
681708
682709 if len (output_images ) > 1 :
683710 output_image = torch .cat (output_images , dim = 0 )
@@ -700,15 +727,15 @@ def load_image(self, image, mask_channel="alpha", scale_by=1.0, resize_mode="lon
700727 return (empty_image , empty_mask , empty_mask_image , 64 , 64 )
701728
702729 @classmethod
703- def IS_CHANGED (cls , image , mask_channel = "alpha" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
730+ def IS_CHANGED (cls , image , mask_channel = "alpha" , upscale_method = "lanczos" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
704731 image_path = folder_paths .get_annotated_filepath (image )
705732 m = hashlib .sha256 ()
706733 with open (image_path , 'rb' ) as f :
707734 m .update (f .read ())
708735 return m .digest ().hex ()
709736
710737 @classmethod
711- def VALIDATE_INPUTS (cls , image , mask_channel = "alpha" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
738+ def VALIDATE_INPUTS (cls , image , mask_channel = "alpha" , upscale_method = "lanczos" , scale_by = 1.0 , resize_mode = "longest_side" , size = 0 , extra_pnginfo = None ):
712739 if not folder_paths .exists_annotated_filepath (image ):
713740 return f"Invalid image file: { image } "
714741
0 commit comments