Skip to content

Commit b202a92

Browse files
authored
Add files via upload
1 parent 5640586 commit b202a92

File tree

1 file changed

+50
-33
lines changed

1 file changed

+50
-33
lines changed

AILab_RMBG.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG
1+
# ComfyUI-RMBG V2.9.1
22
# This custom node for ComfyUI provides functionality for background removal using various models,
33
# including RMBG-2.0, INSPYRENET, BEN, BEN2 and BIREFNET-HR. It leverages deep learning techniques
44
# to process images and generate masks for background removal.
@@ -572,49 +572,51 @@ def process_image(self, image, model, **params):
572572
handle_model_error(download_message)
573573
print("Model files downloaded successfully")
574574

575-
for img in image:
576-
mask = model_instance.process_image(img, model, params)
577-
575+
model_type = AVAILABLE_MODELS[model]["type"]
576+
577+
def _process_pair(img, mask):
578578
if isinstance(mask, list):
579579
masks = [m.convert("L") for m in mask if isinstance(m, Image.Image)]
580-
mask = masks[0] if masks else None
580+
mask_local = masks[0] if masks else None
581581
elif isinstance(mask, Image.Image):
582-
mask = mask.convert("L")
583-
584-
mask_tensor = pil2tensor(mask)
585-
mask_tensor = mask_tensor * (1 + (1 - params["sensitivity"]))
586-
mask_tensor = torch.clamp(mask_tensor, 0, 1)
587-
mask = tensor2pil(mask_tensor)
582+
mask_local = mask.convert("L")
583+
else:
584+
mask_local = mask
585+
586+
mask_tensor_local = pil2tensor(mask_local)
587+
mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"]))
588+
mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1)
589+
mask_img_local = tensor2pil(mask_tensor_local)
588590

589591
if params["mask_blur"] > 0:
590-
mask = mask.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
592+
mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
591593

592594
if params["mask_offset"] != 0:
593595
if params["mask_offset"] > 0:
594596
for _ in range(params["mask_offset"]):
595-
mask = mask.filter(ImageFilter.MaxFilter(3))
597+
mask_img_local = mask_img_local.filter(ImageFilter.MaxFilter(3))
596598
else:
597599
for _ in range(-params["mask_offset"]):
598-
mask = mask.filter(ImageFilter.MinFilter(3))
600+
mask_img_local = mask_img_local.filter(ImageFilter.MinFilter(3))
599601

600602
if params["invert_output"]:
601-
mask = Image.fromarray(255 - np.array(mask))
602-
603-
img_tensor = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
604-
mask_tensor = torch.from_numpy(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0
605-
606-
orig_image = tensor2pil(img)
603+
mask_img_local = Image.fromarray(255 - np.array(mask_img_local))
604+
605+
img_tensor_local = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
606+
mask_tensor_b1hw = torch.from_numpy(np.array(mask_img_local)).unsqueeze(0).unsqueeze(0) / 255.0
607+
608+
orig_image_local = tensor2pil(img)
607609

608610
if params.get("refine_foreground", False):
609-
refined_fg = refine_foreground(img_tensor, mask_tensor)
610-
refined_fg = tensor2pil(refined_fg[0].permute(1, 2, 0))
611-
r, g, b = refined_fg.split()
612-
foreground = Image.merge('RGBA', (r, g, b, mask))
611+
refined_fg_local = refine_foreground(img_tensor_local, mask_tensor_b1hw)
612+
refined_fg_local = tensor2pil(refined_fg_local[0].permute(1, 2, 0))
613+
r, g, b = refined_fg_local.split()
614+
foreground_local = Image.merge('RGBA', (r, g, b, mask_img_local))
613615
else:
614-
orig_rgba = orig_image.convert("RGBA")
615-
r, g, b, _ = orig_rgba.split()
616-
foreground = Image.merge('RGBA', (r, g, b, mask))
617-
616+
orig_rgba_local = orig_image_local.convert("RGBA")
617+
r, g, b, _ = orig_rgba_local.split()
618+
foreground_local = Image.merge('RGBA', (r, g, b, mask_img_local))
619+
618620
if params["background"] == "Color":
619621
def hex_to_rgba(hex_color):
620622
hex_color = hex_color.lstrip('#')
@@ -628,14 +630,29 @@ def hex_to_rgba(hex_color):
628630
return (r, g, b, a)
629631
background_color = params.get("background_color", "#222222")
630632
rgba = hex_to_rgba(background_color)
631-
bg_image = Image.new('RGBA', orig_image.size, rgba)
632-
composite_image = Image.alpha_composite(bg_image, foreground)
633+
bg_image = Image.new('RGBA', orig_image_local.size, rgba)
634+
composite_image = Image.alpha_composite(bg_image, foreground_local)
633635
processed_images.append(pil2tensor(composite_image.convert("RGB")))
634636
else:
635-
processed_images.append(pil2tensor(foreground))
637+
processed_images.append(pil2tensor(foreground_local))
636638

637-
processed_masks.append(pil2tensor(mask))
638-
639+
processed_masks.append(pil2tensor(mask_img_local))
640+
641+
if model_type in ("rmbg", "ben2"):
642+
images_list = [img for img in image]
643+
chunk_size = 4
644+
for start in range(0, len(images_list), chunk_size):
645+
batch_imgs = images_list[start:start + chunk_size]
646+
masks = model_instance.process_image(batch_imgs, model, params)
647+
if isinstance(masks, Image.Image):
648+
masks = [masks]
649+
for img_item, mask_item in zip(batch_imgs, masks):
650+
_process_pair(img_item, mask_item)
651+
else:
652+
for img in image:
653+
mask = model_instance.process_image(img, model, params)
654+
_process_pair(img, mask)
655+
639656
mask_images = []
640657
for mask_tensor in processed_masks:
641658
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)

0 commit comments

Comments
 (0)