1+ # ComfyUI-RMBG v1.6.0
2+ # This custom node for ComfyUI provides functionality for face parsing using Segformer model.
3+ #
4+ # This integration script follows GPL-3.0 License.
5+ # When using or modifying this code, please respect both the original model licenses
6+ # and this integration's license terms.
7+ #
8+ # Source: https://github.com/AILab-AI/ComfyUI-RMBG
9+
10+ import os
11+ import torch
12+ import torch .nn as nn
13+ import numpy as np
14+ from typing import Tuple , Union
15+ from PIL import Image , ImageFilter
16+ from transformers import SegformerImageProcessor , AutoModelForSemanticSegmentation
17+ import folder_paths
18+ from huggingface_hub import hf_hub_download
19+ import shutil
20+ from torchvision import transforms
21+
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor with a leading batch axis.

    Pixel values are divided by 255.0; the array layout produced by
    ``np.array(image)`` is kept and a new axis of size 1 is prepended.
    """
    pixels = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
24+
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a tensor to a PIL image.

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8.
    No axes are squeezed, so the tensor must already have a shape that
    ``Image.fromarray`` accepts (e.g. one batch item, not a whole batch).
    """
    scaled = 255. * image.cpu().numpy()
    as_bytes = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
27+
def image2mask(image: Image.Image) -> torch.Tensor:
    """Return the first channel of an image as a mask tensor.

    PIL inputs are first converted with ``pil2tensor``; anything else is
    used as-is. All singleton axes are squeezed away and index 0 of the
    last axis is returned.
    NOTE(review): presumably expects a multi-channel image; for a
    single-channel input the last axis would be a spatial one - confirm.
    """
    tensor = pil2tensor(image) if isinstance(image, Image.Image) else image
    squeezed = tensor.squeeze()
    return squeezed[..., 0]
32+
def mask2image(mask: torch.Tensor) -> Image.Image:
    """Convert a mask tensor to a PIL image.

    ``tensor2pil`` hands the array straight to ``Image.fromarray`` without
    squeezing, so a mask carrying a leading singleton axis (``(1, H, W)``)
    would be read as a one-row image with ``W`` channels and rejected.
    Leading axes of size 1 are therefore treated as batch axes and removed
    until the mask is two-dimensional; a plain ``(H, W)`` mask is passed
    through unchanged.

    Args:
        mask: Mask tensor; values are scaled by 255 and clipped by
            ``tensor2pil``.

    Returns:
        The mask as a PIL image.
    """
    while mask.dim() > 2 and mask.shape[0] == 1:
        mask = mask[0]
    return tensor2pil(mask)
37+
def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Image.Image:
    """Attach ``mask`` to ``image`` as its alpha channel.

    A tensor mask is converted with ``mask2image`` first. If the mask size
    differs from the image size it is resized with LANCZOS resampling.
    The image is converted to RGB and the mask to mode ``L`` before merging.
    """
    alpha = mask2image(mask) if isinstance(mask, torch.Tensor) else mask
    if alpha.size != image.size:
        alpha = alpha.resize(image.size, Image.Resampling.LANCZOS)
    red, green, blue = image.convert('RGB').split()
    return Image.merge('RGBA', (red, green, blue, alpha.convert('L')))
44+
# Inference device: CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Register <models_dir>/RMBG under the "rmbg" model folder key.
_rmbg_models_dir = os.path.join(folder_paths.models_dir, "RMBG")
folder_paths.add_model_folder_path("rmbg", _rmbg_models_dir)

# Hugging Face repository id for each supported model.
AVAILABLE_MODELS = dict(face_parsing="1038lab/segformer_face")
52+
class FaceSegment:
    """ComfyUI node that builds a mask from selected face-parsing classes.

    The model files are fetched into ``<models_dir>/RMBG/segformer_face`` on
    first use, loaded for the duration of one ``segment_face`` call and
    released again afterwards.
    """

    # Label index predicted by the model for each class name.
    _CLASS_MAP = {
        "Background": 0, "Skin": 1, "Nose": 2, "Eyeglasses": 3,
        "Left-eye": 4, "Right-eye": 5, "Left-eyebrow": 6, "Right-eyebrow": 7,
        "Left-ear": 8, "Right-ear": 9, "Mouth": 10, "Upper-lip": 11,
        "Lower-lip": 12, "Hair": 13, "Hat": 14, "Earring": 15,
        "Necklace": 16, "Neck": 17, "Clothing": 18,
    }

    # Classes exposed as boolean toggles. Background, Hat, Necklace and
    # Clothing are intentionally not offered (not facial features).
    _SELECTABLE_CLASSES = (
        "Skin", "Nose", "Eyeglasses", "Left-eye", "Right-eye",
        "Left-eyebrow", "Right-eyebrow", "Left-ear", "Right-ear", "Mouth",
        "Upper-lip", "Lower-lip", "Hair", "Earring", "Neck",
    )

    # Used when no toggle is enabled. Every entry must be a key of
    # _CLASS_MAP (there is no combined "Eyes" label, only left/right).
    _DEFAULT_CLASSES = ("Skin", "Nose", "Left-eye", "Right-eye", "Mouth")

    # Files that must be present in the cache directory to load the model.
    _REQUIRED_FILES = (
        'config.json',
        'model.safetensors',
        'preprocessor_config.json',
    )

    # RGB fill colours for every non-"Alpha" background choice.
    _BG_COLORS = {
        "black": (0, 0, 0),
        "white": (255, 255, 255),
        "gray": (128, 128, 128),
        "green": (0, 255, 0),
        "blue": (0, 0, 255),
        "red": (255, 0, 0),
    }

    def __init__(self):
        self.processor = None
        self.model = None
        self.cache_dir = os.path.join(folder_paths.models_dir, "RMBG", "segformer_face")

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: one image batch plus optional settings."""
        tooltips = {
            "process_res": "Processing resolution (higher = more VRAM)",
            "mask_blur": "Blur amount for mask edges",
            "mask_offset": "Expand/Shrink mask boundary",
            "background_color": "Choose background color (Alpha = transparent)",
            "invert_output": "Invert both image and mask output",
        }

        # Class toggles come first so they keep their position in the UI.
        optional = {name: ("BOOLEAN", {"default": False}) for name in cls._SELECTABLE_CLASSES}
        optional.update({
            "process_res": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 32, "tooltip": tooltips["process_res"]}),
            "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
            "mask_offset": ("INT", {"default": 0, "min": -20, "max": 20, "step": 1, "tooltip": tooltips["mask_offset"]}),
            "background_color": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background_color"]}),
            "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
        })

        return {
            "required": {
                "images": ("IMAGE",),
            },
            "optional": optional,
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("images", "mask")
    FUNCTION = "segment_face"
    CATEGORY = "🧪AILab/🧽RMBG"

    def check_model_cache(self) -> Tuple[bool, str]:
        """Check that the cache directory holds every required model file.

        Returns:
            ``(ok, message)``; ``message`` names the missing files on failure.
        """
        if not os.path.exists(self.cache_dir):
            return False, "Model directory not found"

        missing_files = [
            name for name in self._REQUIRED_FILES
            if not os.path.exists(os.path.join(self.cache_dir, name))
        ]
        if missing_files:
            return False, f"Required model files missing: {', '.join(missing_files)} "
        return True, "Model cache verified"

    def clear_model(self) -> None:
        """Release the model and processor and free cached CUDA memory.

        The processor is reset even when no model is loaded, so a partially
        failed load cannot leave the instance in a half-initialised state.
        """
        if self.model is not None:
            self.model.cpu()
        self.model = None
        self.processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def download_model_files(self) -> Tuple[bool, str]:
        """Download the required model files into the cache directory.

        Returns:
            ``(ok, message)``; download errors are reported through the
            message instead of being raised.
        """
        model_id = AVAILABLE_MODELS["face_parsing"]

        os.makedirs(self.cache_dir, exist_ok=True)
        print("Downloading face parsing model files...")

        try:
            for file_name in self._REQUIRED_FILES:
                print(f"Downloading {file_name} ...")
                downloaded_path = hf_hub_download(
                    repo_id=model_id,
                    filename=file_name,
                    local_dir=self.cache_dir,
                    local_dir_use_symlinks=False,
                )

                # Move the file if the hub client placed it elsewhere.
                # Paths are normalised so a relative cache_dir still compares
                # equal to an absolute download path.
                download_dir = os.path.abspath(os.path.dirname(downloaded_path))
                if download_dir != os.path.abspath(self.cache_dir):
                    shutil.move(downloaded_path, os.path.join(self.cache_dir, file_name))
            return True, "Model files downloaded successfully"
        except Exception as e:
            return False, f"Error downloading model files: {e} "

    def _ensure_model_loaded(self) -> None:
        """Download the model files if needed and load processor and model."""
        cache_ok, message = self.check_model_cache()
        if not cache_ok:
            print(f"Cache check: {message} ")
            downloaded, download_message = self.download_model_files()
            if not downloaded:
                raise RuntimeError(download_message)

        # Check both attributes: a processor without a model can be left
        # behind if loading the model raised on an earlier call.
        if self.processor is None or self.model is None:
            self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
            model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
            model.eval()
            model.requires_grad_(False)
            model.to(device)
            self.model = model

    def _predict_labels(self, pil_image, resize_to_tensor, normalize):
        """Return the per-pixel class index map at the image's original size."""
        width, height = pil_image.size

        input_tensor = resize_to_tensor(pil_image)
        if input_tensor.shape[0] == 4:
            # Drop the alpha channel; the model input is 3-channel.
            input_tensor = input_tensor[:3]
        input_tensor = normalize(input_tensor).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = self.model(input_tensor).logits.cpu()
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
            )
            return upsampled_logits.argmax(dim=1)[0]

    @staticmethod
    def _refine_mask(mask_image, mask_blur, mask_offset, invert_output):
        """Apply blur, then grow/shrink, then optional inversion to the mask."""
        if mask_blur > 0:
            mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

        # Filter sizes are always odd: 2 * |offset| + 1.
        if mask_offset > 0:
            mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
        elif mask_offset < 0:
            mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

        if invert_output:
            mask_image = Image.fromarray(255 - np.array(mask_image))
        return mask_image

    def _compose_output(self, orig_image, mask_image, background_color):
        """Build the output image tensor for one input image.

        "Alpha" yields an RGBA tensor with the mask as alpha channel; any
        other choice composites the cut-out over a solid colour and yields RGB.
        """
        rgba_image = RGB2RGBA(orig_image, mask_image)
        if background_color == "Alpha":
            return pil2tensor(rgba_image)

        fill = (*self._BG_COLORS[background_color], 255)
        bg_image = Image.new('RGBA', orig_image.size, fill)
        composite_image = Image.alpha_composite(bg_image, rgba_image)
        return pil2tensor(composite_image.convert('RGB'))

    def segment_face(self, images, process_res=512, mask_blur=0, mask_offset=0, background_color="Alpha", invert_output=False, **class_selections):
        """Segment the selected face classes in every image of the batch.

        Args:
            images: Iterable of image tensors; each item is converted with
                ``tensor2pil`` (assumes height x width x channels per item
                - TODO confirm against caller).
            process_res: Side length the image is resized to for inference.
            mask_blur: Gaussian blur radius applied to the mask (0 = off).
            mask_offset: Positive values grow the mask, negative shrink it.
            background_color: "Alpha" or one of the solid colour names.
            invert_output: Invert the mask before it is applied.
            **class_selections: Class name -> bool toggles. If none is
                enabled, Skin, Nose, both eyes and Mouth are used.

        Returns:
            ``(images, masks)`` tensors concatenated along the batch axis.

        Raises:
            RuntimeError: On any failure; the original error is chained.
        """
        try:
            self._ensure_model_loaded()

            selected_classes = [name for name, enabled in class_selections.items() if enabled]
            if not selected_classes:
                selected_classes = list(self._DEFAULT_CLASSES)
            class_ids = [self._CLASS_MAP[name] for name in selected_classes]

            # Built once per call rather than once per image.
            resize_to_tensor = transforms.Compose([
                transforms.Resize((process_res, process_res)),
                transforms.ToTensor(),
            ])
            normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

            result_images = []
            result_masks = []

            for image in images:
                orig_image = tensor2pil(image)
                pred_seg = self._predict_labels(orig_image, resize_to_tensor, normalize)

                # Union of all selected classes as a boolean pixel map.
                selected_pixels = torch.zeros_like(pred_seg, dtype=torch.bool)
                for class_id in class_ids:
                    selected_pixels |= pred_seg == class_id

                mask_image = Image.fromarray((selected_pixels.numpy() * 255).astype(np.uint8))
                mask_image = self._refine_mask(mask_image, mask_blur, mask_offset, invert_output)

                result_images.append(self._compose_output(orig_image, mask_image, background_color))
                result_masks.append(pil2tensor(mask_image))

            return (torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0))

        except Exception as e:
            self.clear_model()
            raise RuntimeError(f"Error in Face Parsing processing: {e} ") from e
        finally:
            # The model may already be None here (load failure or the except
            # branch above); touching .training then would mask the real error.
            if self.model is not None and not self.model.training:
                self.clear_model()
274+
# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {"FaceSegment": FaceSegment}

# Node identifier -> display name.
NODE_DISPLAY_NAME_MAPPINGS = {"FaceSegment": "Face Segment (RMBG)"}
0 commit comments