Skip to content

Commit 70c372d

Browse files
authored
Add files via upload
1 parent 919fecf commit 70c372d

File tree

8 files changed

+326
-30
lines changed

8 files changed

+326
-30
lines changed

AILab_ClothSegment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG v1.5.0
1+
# ComfyUI-RMBG v1.6.0
22
# This custom node for ComfyUI provides functionality for background removal using various models,
33
# including RMBG-2.0, INSPYRENET, and BEN. It leverages deep learning techniques
44
# to process images and generate masks for background removal.
@@ -10,7 +10,7 @@
1010
# When using or modifying this code, please respect both the original model licenses
1111
# and this integration's license terms.
1212
#
13-
# Source: https://github.com/1038lab/ComfyUI-RMBG
13+
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
1414

1515
import os
1616
import torch
@@ -276,4 +276,4 @@ def segment_clothes(self, images, process_res=1024, mask_blur=0, mask_offset=0,
276276

277277
NODE_DISPLAY_NAME_MAPPINGS = {
278278
"ClothesSegment": "Clothes Segment (RMBG)"
279-
}
279+
}

AILab_FaceSegment.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# ComfyUI-RMBG v1.6.0
2+
# This custom node for ComfyUI provides functionality for face parsing using Segformer model.
3+
#
4+
# This integration script follows GPL-3.0 License.
5+
# When using or modifying this code, please respect both the original model licenses
6+
# and this integration's license terms.
7+
#
8+
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
9+
10+
import os
11+
import torch
12+
import torch.nn as nn
13+
import numpy as np
14+
from typing import Tuple, Union
15+
from PIL import Image, ImageFilter
16+
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
17+
import folder_paths
18+
from huggingface_hub import hf_hub_download
19+
import shutil
20+
from torchvision import transforms
21+
22+
def pil2tensor(image: Image.Image) -> torch.Tensor:
23+
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
24+
25+
def tensor2pil(image: torch.Tensor) -> Image.Image:
26+
return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
27+
28+
def image2mask(image: Image.Image) -> torch.Tensor:
29+
if isinstance(image, Image.Image):
30+
image = pil2tensor(image)
31+
return image.squeeze()[..., 0]
32+
33+
def mask2image(mask: torch.Tensor) -> Image.Image:
34+
if len(mask.shape) == 2:
35+
mask = mask.unsqueeze(0)
36+
return tensor2pil(mask)
37+
38+
def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Image.Image:
39+
if isinstance(mask, torch.Tensor):
40+
mask = mask2image(mask)
41+
if mask.size != image.size:
42+
mask = mask.resize(image.size, Image.Resampling.LANCZOS)
43+
return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L')))
44+
45+
device = "cuda" if torch.cuda.is_available() else "cpu"
46+
47+
folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))
48+
49+
AVAILABLE_MODELS = {
50+
"face_parsing": "1038lab/segformer_face"
51+
}
52+
53+
class FaceSegment:
54+
def __init__(self):
55+
self.processor = None
56+
self.model = None
57+
self.cache_dir = os.path.join(folder_paths.models_dir, "RMBG", "segformer_face")
58+
59+
@classmethod
60+
def INPUT_TYPES(cls):
61+
available_classes = [
62+
# "Background", # Not a facial feature
63+
"Skin", "Nose", "Eyeglasses", "Left-eye", "Right-eye",
64+
"Left-eyebrow", "Right-eyebrow", "Left-ear", "Right-ear", "Mouth",
65+
"Upper-lip", "Lower-lip", "Hair", "Earring", "Neck",
66+
# "Hat", # Not a facial feature
67+
# "Necklace", # Not a facial feature
68+
# "Clothing" # Not a facial feature
69+
]
70+
71+
tooltips = {
72+
"process_res": "Processing resolution (higher = more VRAM)",
73+
"mask_blur": "Blur amount for mask edges",
74+
"mask_offset": "Expand/Shrink mask boundary",
75+
"background_color": "Choose background color (Alpha = transparent)",
76+
"invert_output": "Invert both image and mask output",
77+
}
78+
79+
return {
80+
"required": {
81+
"images": ("IMAGE",),
82+
},
83+
"optional": {
84+
**{cls_name: ("BOOLEAN", {"default": False})
85+
for cls_name in available_classes},
86+
"process_res": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 32, "tooltip": tooltips["process_res"]}),
87+
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
88+
"mask_offset": ("INT", {"default": 0, "min": -20, "max": 20, "step": 1, "tooltip": tooltips["mask_offset"]}),
89+
"background_color": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background_color"]}),
90+
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
91+
},
92+
}
93+
94+
RETURN_TYPES = ("IMAGE", "MASK")
95+
RETURN_NAMES = ("images", "mask")
96+
FUNCTION = "segment_face"
97+
CATEGORY = "🧪AILab/🧽RMBG"
98+
99+
def check_model_cache(self):
100+
if not os.path.exists(self.cache_dir):
101+
return False, "Model directory not found"
102+
103+
required_files = [
104+
'config.json',
105+
'model.safetensors',
106+
'preprocessor_config.json'
107+
]
108+
109+
missing_files = [f for f in required_files if not os.path.exists(os.path.join(self.cache_dir, f))]
110+
if missing_files:
111+
return False, f"Required model files missing: {', '.join(missing_files)}"
112+
return True, "Model cache verified"
113+
114+
def clear_model(self):
115+
if self.model is not None:
116+
self.model.cpu()
117+
del self.model
118+
self.model = None
119+
self.processor = None
120+
torch.cuda.empty_cache()
121+
122+
def download_model_files(self):
123+
model_id = AVAILABLE_MODELS["face_parsing"]
124+
model_files = {
125+
'config.json': 'config.json',
126+
'model.safetensors': 'model.safetensors',
127+
'preprocessor_config.json': 'preprocessor_config.json'
128+
}
129+
130+
os.makedirs(self.cache_dir, exist_ok=True)
131+
print(f"Downloading face parsing model files...")
132+
133+
try:
134+
for save_name, repo_path in model_files.items():
135+
print(f"Downloading {save_name}...")
136+
downloaded_path = hf_hub_download(
137+
repo_id=model_id,
138+
filename=repo_path,
139+
local_dir=self.cache_dir,
140+
local_dir_use_symlinks=False
141+
)
142+
143+
if os.path.dirname(downloaded_path) != self.cache_dir:
144+
target_path = os.path.join(self.cache_dir, save_name)
145+
shutil.move(downloaded_path, target_path)
146+
return True, "Model files downloaded successfully"
147+
except Exception as e:
148+
return False, f"Error downloading model files: {str(e)}"
149+
150+
def segment_face(self, images, process_res=512, mask_blur=0, mask_offset=0, background_color="Alpha", invert_output=False, **class_selections):
151+
try:
152+
# Check and download model if needed
153+
cache_status, message = self.check_model_cache()
154+
if not cache_status:
155+
print(f"Cache check: {message}")
156+
download_status, download_message = self.download_model_files()
157+
if not download_status:
158+
raise RuntimeError(download_message)
159+
160+
# Load model if needed
161+
if self.processor is None:
162+
self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
163+
self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
164+
self.model.eval()
165+
for param in self.model.parameters():
166+
param.requires_grad = False
167+
self.model.to(device)
168+
169+
# Class mapping for segmentation
170+
class_map = {
171+
"Background": 0, "Skin": 1, "Nose": 2, "Eyeglasses": 3,
172+
"Left-eye": 4, "Right-eye": 5, "Left-eyebrow": 6, "Right-eyebrow": 7,
173+
"Left-ear": 8, "Right-ear": 9, "Mouth": 10, "Upper-lip": 11,
174+
"Lower-lip": 12, "Hair": 13, "Hat": 14, "Earring": 15,
175+
"Necklace": 16, "Neck": 17, "Clothing": 18
176+
}
177+
178+
# Get selected classes
179+
selected_classes = [name for name, selected in class_selections.items() if selected]
180+
if not selected_classes:
181+
selected_classes = ["Skin", "Nose", "Eyes", "Mouth"]
182+
183+
# Image preprocessing
184+
transform_image = transforms.Compose([
185+
transforms.Resize((process_res, process_res)),
186+
transforms.ToTensor(),
187+
])
188+
189+
batch_tensor = []
190+
batch_masks = []
191+
192+
for image in images:
193+
orig_image = tensor2pil(image)
194+
w, h = orig_image.size
195+
196+
input_tensor = transform_image(orig_image)
197+
198+
if input_tensor.shape[0] == 4:
199+
input_tensor = input_tensor[:3]
200+
201+
input_tensor = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input_tensor)
202+
203+
input_tensor = input_tensor.unsqueeze(0).to(device)
204+
205+
with torch.no_grad():
206+
outputs = self.model(input_tensor)
207+
logits = outputs.logits.cpu()
208+
upsampled_logits = nn.functional.interpolate(
209+
logits,
210+
size=(h, w),
211+
mode="bilinear",
212+
align_corners=False,
213+
)
214+
pred_seg = upsampled_logits.argmax(dim=1)[0]
215+
216+
# Combine selected class masks
217+
combined_mask = None
218+
for class_name in selected_classes:
219+
mask = (pred_seg == class_map[class_name]).float()
220+
if combined_mask is None:
221+
combined_mask = mask
222+
else:
223+
combined_mask = torch.clamp(combined_mask + mask, 0, 1)
224+
225+
# Convert mask to PIL for processing
226+
mask_image = Image.fromarray((combined_mask.numpy() * 255).astype(np.uint8))
227+
228+
if mask_blur > 0:
229+
mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
230+
231+
if mask_offset != 0:
232+
if mask_offset > 0:
233+
mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
234+
else:
235+
mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))
236+
237+
if invert_output:
238+
mask_image = Image.fromarray(255 - np.array(mask_image))
239+
240+
# Handle background color
241+
if background_color == "Alpha":
242+
rgba_image = RGB2RGBA(orig_image, mask_image)
243+
result_image = pil2tensor(rgba_image)
244+
else:
245+
bg_colors = {
246+
"black": (0, 0, 0),
247+
"white": (255, 255, 255),
248+
"gray": (128, 128, 128),
249+
"green": (0, 255, 0),
250+
"blue": (0, 0, 255),
251+
"red": (255, 0, 0)
252+
}
253+
254+
rgba_image = RGB2RGBA(orig_image, mask_image)
255+
bg_image = Image.new('RGBA', orig_image.size, (*bg_colors[background_color], 255))
256+
composite_image = Image.alpha_composite(bg_image, rgba_image)
257+
result_image = pil2tensor(composite_image.convert('RGB'))
258+
259+
batch_tensor.append(result_image)
260+
batch_masks.append(pil2tensor(mask_image))
261+
262+
# Prepare final output
263+
batch_tensor = torch.cat(batch_tensor, dim=0)
264+
batch_masks = torch.cat(batch_masks, dim=0)
265+
266+
return (batch_tensor, batch_masks)
267+
268+
except Exception as e:
269+
self.clear_model()
270+
raise RuntimeError(f"Error in Face Parsing processing: {str(e)}")
271+
finally:
272+
if not self.model.training:
273+
self.clear_model()
274+
275+
NODE_CLASS_MAPPINGS = {
276+
"FaceSegment": FaceSegment
277+
}
278+
279+
NODE_DISPLAY_NAME_MAPPINGS = {
280+
"FaceSegment": "Face Segment (RMBG)"
281+
}

AILab_FashionSegment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG v1.5.0
1+
# ComfyUI-RMBG v1.6.0
22
# This custom node for ComfyUI provides functionality for fashion segmentation using segformer-b3-fashion model.
33
# It leverages deep learning techniques to process images and generate masks for fashion items segmentation.
44

@@ -9,7 +9,7 @@
99
# When using or modifying this code, please respect both the original model licenses
1010
# and this integration's license terms.
1111
#
12-
# Source: https://github.com/1038lab/ComfyUI-RMBG
12+
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
1313

1414
import os
1515
import torch
@@ -354,4 +354,4 @@ def __del__(self):
354354
NODE_DISPLAY_NAME_MAPPINGS = {
355355
"FashionSegmentAccessories": "Accessories Segment (RMBG)",
356356
"FashionSegmentClothing": "Fashion Segment (RMBG)"
357-
}
357+
}

AILab_RMBG.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG v1.5.0
1+
# ComfyUI-RMBG v1.6.0
22
# This custom node for ComfyUI provides functionality for background removal using various models,
33
# including RMBG-2.0, INSPYRENET, and BEN. It leverages deep learning techniques
44
# to process images and generate masks for background removal.
@@ -12,7 +12,7 @@
1212
# When using or modifying this code, please respect both the original model licenses
1313
# and this integration's license terms.
1414
#
15-
# Source: https://github.com/1038lab/ComfyUI-RMBG
15+
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
1616

1717
import os
1818
import torch
@@ -435,4 +435,4 @@ def process_image(self, image, model, **params):
435435

436436
NODE_DISPLAY_NAME_MAPPINGS = {
437437
"RMBG": "Remove Background (RMBG)"
438-
}
438+
}

AILab_Segment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG v1.5.0
1+
# ComfyUI-RMBG v1.6.0
22
# This custom node for ComfyUI provides functionality for background removal using various models,
33
# including RMBG-2.0, INSPYRENET, and BEN. It leverages deep learning techniques
44
# to process images and generate masks for background removal.
@@ -11,7 +11,7 @@
1111
# When using or modifying this code, please respect both the original model licenses
1212
# and this integration's license terms.
1313
#
14-
# Source: https://github.com/1038lab/ComfyUI-RMBG
14+
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
1515

1616
import os
1717
import sys
@@ -345,4 +345,4 @@ def get_local_filepath(self, url, dirname, local_file_name=None):
345345

346346
NODE_DISPLAY_NAME_MAPPINGS = {
347347
"Segment": "Segment (RMBG)"
348-
}
348+
}

0 commit comments

Comments
 (0)