Skip to content

Commit f09fd46

Browse files
authored
Add files via upload
1 parent c7afaf2 commit f09fd46

File tree

8 files changed

+293
-24
lines changed

8 files changed

+293
-24
lines changed

AILab_ImageMaskTools.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG v2.5.0
1+
# ComfyUI-RMBG v2.6.0
22
#
33
# This node facilitates background removal using various models, including RMBG-2.0, INSPYRENET, BEN, BEN2, and BIREFNET-HR.
44
# It utilizes advanced deep learning techniques to process images and generate accurate masks for background removal.
@@ -31,7 +31,8 @@
3131
#
3232
# 5. Input Nodes:
3333
# - ColorInput: A node for inputting colors in various formats.
34-
34+
#
35+
# License: GPL-3.0
3536
# These nodes are crafted to streamline common image and mask operations within ComfyUI workflows.
3637

3738
import os
@@ -587,6 +588,8 @@ def _resize_if_needed(self, mask, target_shape):
587588

588589
# Image loader node
589590
class AILab_LoadImage:
591+
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
592+
590593
@classmethod
591594
def INPUT_TYPES(cls):
592595
input_dir = folder_paths.get_input_directory()
@@ -596,6 +599,7 @@ def INPUT_TYPES(cls):
596599
"required": {
597600
"image": (sorted(files) or [""], {"image_upload": True}),
598601
"mask_channel": (["alpha", "red", "green", "blue"], {"default": "alpha", "tooltip": "Select channel to extract mask from"}),
602+
"upscale_method": (cls.upscale_methods, {"default": "lanczos", "tooltip": "Method used for resizing the image"}),
599603
"scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01, "tooltip": "Scale image by this factor (ignored if size > 0)"}),
600604
"resize_mode": (["longest_side", "shortest_side", "width", "height"], {"default": "longest_side", "tooltip": "Choose how to resize the image"}),
601605
"size": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Target size for the selected resize mode (0 = keep original size)"}),
@@ -611,14 +615,28 @@ def INPUT_TYPES(cls):
611615
FUNCTION = "load_image"
612616
OUTPUT_NODE = False
613617

614-
def load_image(self, image, mask_channel="alpha", scale_by=1.0, resize_mode="longest_side", size=0, extra_pnginfo=None):
618+
def load_image(self, image, mask_channel="alpha", upscale_method="lanczos", scale_by=1.0, resize_mode="longest_side", size=0, extra_pnginfo=None):
615619
try:
616620
image_path = folder_paths.get_annotated_filepath(image)
617621
img = Image.open(image_path)
618622

619623
orig_width, orig_height = img.size
620624

621-
# Image resizing logic
625+
resampling_map = {
626+
"nearest-exact": Image.NEAREST,
627+
"bilinear": Image.BILINEAR,
628+
"area": Image.BOX,
629+
"bicubic": Image.BICUBIC,
630+
"lanczos": Image.LANCZOS
631+
}
632+
resampling = resampling_map.get(upscale_method, Image.LANCZOS)
633+
634+
has_alpha = 'A' in img.getbands()
635+
if has_alpha and mask_channel == "alpha":
636+
original_alpha = img.getchannel('A')
637+
638+
img_rgb = img.convert('RGB')
639+
622640
if size > 0:
623641
if resize_mode == "longest_side":
624642
if orig_width >= orig_height:
@@ -627,57 +645,66 @@ def load_image(self, image, mask_channel="alpha", scale_by=1.0, resize_mode="lon
627645
else:
628646
new_height = size
629647
new_width = int(orig_width * (size / orig_height))
630-
img = img.resize((new_width, new_height), Image.LANCZOS)
648+
img_rgb = img_rgb.resize((new_width, new_height), resampling)
631649
elif resize_mode == "shortest_side":
632650
if orig_width <= orig_height:
633651
new_width = size
634652
new_height = int(orig_height * (size / orig_width))
635653
else:
636654
new_height = size
637655
new_width = int(orig_width * (size / orig_height))
638-
img = img.resize((new_width, new_height), Image.LANCZOS)
656+
img_rgb = img_rgb.resize((new_width, new_height), resampling)
639657
elif resize_mode == "width":
640658
new_width = size
641659
new_height = int(orig_height * (size / orig_width))
642-
img = img.resize((new_width, new_height), Image.LANCZOS)
660+
img_rgb = img_rgb.resize((new_width, new_height), resampling)
643661
elif resize_mode == "height":
644662
new_height = size
645663
new_width = int(orig_width * (size / orig_height))
646-
img = img.resize((new_width, new_height), Image.LANCZOS)
664+
img_rgb = img_rgb.resize((new_width, new_height), resampling)
647665
elif scale_by != 1.0:
648666
new_width = int(orig_width * scale_by)
649667
new_height = int(orig_height * scale_by)
650-
img = img.resize((new_width, new_height), Image.LANCZOS)
668+
img_rgb = img_rgb.resize((new_width, new_height), resampling)
651669

652-
width, height = img.size
670+
width, height = img_rgb.size
671+
672+
mask = None
673+
if mask_channel == "alpha" and has_alpha:
674+
if (size > 0 or scale_by != 1.0) and 'original_alpha' in locals():
675+
mask_img = original_alpha.resize((width, height), resampling)
676+
mask = np.array(mask_img).astype(np.float32) / 255.0
677+
mask = 1. - torch.from_numpy(mask)
653678

654679
output_images = []
655680
output_masks = []
656-
for i in ImageSequence.Iterator(img):
681+
682+
for i in ImageSequence.Iterator(img_rgb):
657683
i = ImageOps.exif_transpose(i)
658684
if i.mode == 'I':
659685
i = i.point(lambda i: i * (1 / 255))
660-
image = i.convert("RGB")
661-
image = np.array(image).astype(np.float32) / 255.0
686+
687+
if i.mode != 'RGB':
688+
i = i.convert('RGB')
689+
690+
image = np.array(i).astype(np.float32) / 255.0
662691
image = torch.from_numpy(image)[None,]
663692

664-
if mask_channel == "alpha" and 'A' in i.getbands():
665-
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
666-
mask = 1. - torch.from_numpy(mask)
693+
if mask is not None:
694+
output_masks.append(mask.unsqueeze(0))
667695
elif mask_channel == "red" and 'R' in i.getbands():
668696
mask = np.array(i.getchannel('R')).astype(np.float32) / 255.0
669-
mask = torch.from_numpy(mask)
697+
output_masks.append(torch.from_numpy(mask).unsqueeze(0))
670698
elif mask_channel == "green" and 'G' in i.getbands():
671699
mask = np.array(i.getchannel('G')).astype(np.float32) / 255.0
672-
mask = torch.from_numpy(mask)
700+
output_masks.append(torch.from_numpy(mask).unsqueeze(0))
673701
elif mask_channel == "blue" and 'B' in i.getbands():
674702
mask = np.array(i.getchannel('B')).astype(np.float32) / 255.0
675-
mask = torch.from_numpy(mask)
703+
output_masks.append(torch.from_numpy(mask).unsqueeze(0))
676704
else:
677-
mask = torch.ones((height, width), dtype=torch.float32, device="cpu")
705+
output_masks.append(torch.ones((1, height, width), dtype=torch.float32, device="cpu"))
678706

679707
output_images.append(image)
680-
output_masks.append(mask.unsqueeze(0))
681708

682709
if len(output_images) > 1:
683710
output_image = torch.cat(output_images, dim=0)
@@ -700,15 +727,15 @@ def load_image(self, image, mask_channel="alpha", scale_by=1.0, resize_mode="lon
700727
return (empty_image, empty_mask, empty_mask_image, 64, 64)
701728

702729
@classmethod
def IS_CHANGED(cls, image, mask_channel="alpha", upscale_method="lanczos", scale_by=1.0, resize_mode="longest_side", size=0, extra_pnginfo=None):
    """Return a content hash of the image file so ComfyUI re-runs the node when the file changes on disk.

    Only the file bytes matter for cache invalidation; the processing
    parameters are part of the node state and are tracked separately by
    ComfyUI, so they are intentionally ignored here.
    """
    image_path = folder_paths.get_annotated_filepath(image)
    m = hashlib.sha256()
    # Stream in fixed-size chunks instead of f.read() so large images do not
    # get loaded into memory just to be hashed.
    with open(image_path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            m.update(chunk)
    return m.digest().hex()
709736

710737
@classmethod
711-
def VALIDATE_INPUTS(cls, image, mask_channel="alpha", scale_by=1.0, resize_mode="longest_side", size=0, extra_pnginfo=None):
738+
def VALIDATE_INPUTS(cls, image, mask_channel="alpha", upscale_method="lanczos", scale_by=1.0, resize_mode="longest_side", size=0, extra_pnginfo=None):
712739
if not folder_paths.exists_annotated_filepath(image):
713740
return f"Invalid image file: {image}"
714741

AILab_InpaintTools.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# ComfyUI-RMBG v2.6.0
2+
#
3+
# AILab Inpaint Tools
4+
# A collection of specialized nodes for inpainting tasks in ComfyUI.
5+
# Features a set of utilities for mask processing, latent conditioning, and inpainting workflows.
6+
#
7+
# 1. Inpaint Nodes:
8+
# - AILab_ReferenceLatentMask: A node for inpainting tasks with the Flux Kontext model, using a reference latent and mask for precise region conditioning
9+
#
10+
# License: GPL-3.0
11+
# These nodes are crafted to streamline common image and mask operations within ComfyUI workflows.
12+
13+
import torch
14+
import node_helpers
15+
16+
def expand_mask(mask, expand_amount):
    """Grow (expand_amount > 0) or shrink (expand_amount < 0) a mask by that many pixels.

    Implements morphological dilation/erosion via a box-kernel convolution.

    Args:
        mask: Float tensor, shape (H, W), (B, H, W) or (B, 1, H, W), values in [0, 1].
        expand_amount: Pixels to dilate (positive) or erode (negative); 0 is a no-op.

    Returns:
        A float tensor with the SAME shape as `mask`, binarized to {0.0, 1.0}.
    """
    if expand_amount == 0:
        return mask

    import torch.nn.functional as F

    # Threshold to a hard binary mask before applying morphology.
    binary_mask = (mask > 0.5).float()
    kernel_size = max(3, abs(expand_amount) * 2 + 1)
    kernel = torch.ones(1, 1, kernel_size, kernel_size, device=mask.device)

    # conv2d needs (N, C, H, W); fold any leading dims into the batch axis.
    flat = binary_mask.reshape(-1, 1, mask.shape[-2], mask.shape[-1])
    neighborhood_sum = F.conv2d(flat, kernel, padding=kernel_size // 2)

    if expand_amount > 0:
        # Dilation: pixel turns on if ANY neighbor under the kernel is on.
        result = (neighborhood_sum > 0).float()
    else:
        # Erosion: pixel stays on only if the ENTIRE kernel neighborhood is on.
        result = (neighborhood_sum >= kernel_size * kernel_size).float()

    # BUGFIX: always restore the caller's shape. The previous code only
    # squeezed dim 1 for 3-D inputs, so a 2-D (H, W) mask came back (1, 1, H, W).
    return result.reshape(mask.shape)
47+
48+
49+
def blur_mask(mask, blur_amount):
    """Soften mask edges with a separable Gaussian blur.

    Args:
        mask: Float tensor, shape (H, W), (B, H, W) or (B, 1, H, W).
        blur_amount: Gaussian sigma in pixels; 0 disables blurring.

    Returns:
        Blurred float tensor with the SAME shape as `mask`.
    """
    if blur_amount == 0:
        return mask

    import math
    import torch.nn.functional as F

    # conv2d needs (N, C, H, W); fold any leading dims into the batch axis.
    x = mask.reshape(-1, 1, mask.shape[-2], mask.shape[-1])

    # Kernel wide enough to cover ~3 sigma each side; always odd, at least 3.
    kernel_size = max(3, math.ceil(blur_amount * 3) * 2 + 1)
    sigma = blur_amount
    half_kernel = kernel_size // 2

    grid = torch.arange(-half_kernel, half_kernel + 1, device=mask.device).float()
    gaussian = torch.exp(-0.5 * (grid / sigma) ** 2)
    gaussian = gaussian / gaussian.sum()  # normalize so total mask mass is preserved

    # Separable convolution: horizontal pass, then vertical pass.
    blurred = F.conv2d(x, gaussian.view(1, 1, 1, kernel_size), padding=(0, half_kernel))
    blurred = F.conv2d(blurred, gaussian.view(1, 1, kernel_size, 1), padding=(half_kernel, 0))

    # BUGFIX: always restore the caller's shape. The previous code only
    # squeezed dim 1 for 3-D inputs, so a 2-D (H, W) mask came back 4-D.
    return blurred.reshape(mask.shape)
76+
77+
class AILab_ReferenceLatentMask:
    """Prepare Flux-Kontext inpaint conditioning from a VAE latent and a mask.

    Attaches the latent to the conditioning both as a concat image and as a
    reference latent, and optionally restricts denoising to the processed
    (expanded / blurred) mask region via `noise_mask`.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "conditioning": ("CONDITIONING", {"tooltip": "Base conditioning input for inpainting task"}),
                "latent": ("LATENT", {"tooltip": "Encoded latent from VAE"}),
                "mask": ("MASK", {"tooltip": "Area to inpaint (white regions)"}),
                "expand": ("INT", {"default": 5, "min": -64, "max": 64, "step": 1, "tooltip": "Grow mask (+) or shrink mask (-)"}),
                "blur": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 64.0, "step": 0.1, "tooltip": "Soften mask edges"}),
                "mask_only": ("BOOLEAN", {"default": True, "tooltip": "Only generate content in masked area"}),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "LATENT", "MASK")
    RETURN_NAMES = ("CONDITIONING", "LATENT", "MASK")
    FUNCTION = "prepare_inpaint_conditioning"
    CATEGORY = "🧪AILab/🧽RMBG/🎭Inpaint"

    def add_latent_to_conditioning(self, conditioning, latent=None):
        # Append the latent samples as a Kontext reference latent; pass the
        # conditioning through untouched when no latent is supplied.
        if latent is None:
            return conditioning
        return node_helpers.conditioning_set_values(
            conditioning,
            {"reference_latents": [latent["samples"]]},
            append=True,
        )

    def prepare_inpaint_conditioning(self, conditioning, latent, mask, expand=5, blur=3.0, mask_only=True):
        """Return (conditioning, latent, mask) ready for an inpaint sampler."""
        # Post-process the mask: optional grow/shrink, then optional feathering.
        processed_mask = mask
        if expand != 0:
            processed_mask = expand_mask(processed_mask, expand)
        if blur > 0:
            processed_mask = blur_mask(processed_mask, blur)

        # Bake the latent image and mask into the conditioning for inpainting.
        conditioned = node_helpers.conditioning_set_values(
            conditioning,
            {
                "concat_latent_image": latent["samples"],
                "concat_mask": processed_mask,
            },
        )
        conditioned = self.add_latent_to_conditioning(conditioned, latent)

        output_latent = {"samples": latent["samples"]}
        if mask_only:
            # Limits denoising to the masked area.
            output_latent["noise_mask"] = processed_mask

        return (conditioned, output_latent, processed_mask)
138+
139+
140+
# Registration tables: ComfyUI reads these module-level dicts to discover the
# nodes exported by this file and the display names shown in the UI.
NODE_CLASS_MAPPINGS = {
    "AILab_ReferenceLatentMask": AILab_ReferenceLatentMask,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "AILab_ReferenceLatentMask": "Reference Latent Mask (RMBG) 🖼️🎭",
}

locales/en/nodeDefs.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,5 +416,21 @@
416416
"2": { "name": "WIDTH" },
417417
"3": { "name": "HEIGHT" }
418418
}
419+
},
420+
"AILab_ReferenceLatentMask": {
421+
"display_name": "Kontext Reference Latent Mask (RMBG) 🎭",
422+
"inputs": {
423+
"conditioning": { "name": "Conditioning" },
424+
"latent": { "name": "Latent" },
425+
"mask": { "name": "Mask" },
426+
"expand": { "name": "Expand" },
427+
"blur": { "name": "Blur" },
428+
"mask_only": { "name": "Mask Only" }
429+
},
430+
"outputs": {
431+
"0": { "name": "CONDITIONING" },
432+
"1": { "name": "LATENT" },
433+
"2": { "name": "MASK" }
434+
}
419435
}
420436
}

locales/fr/nodeDefs.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,5 +415,21 @@
415415
"2": { "name": "LARGEUR" },
416416
"3": { "name": "HAUTEUR" }
417417
}
418+
},
419+
"AILab_ReferenceLatentMask": {
420+
"display_name": "Kontext 参考潜伏遮罩 (RMBG) 🎭",
421+
"inputs": {
422+
"conditioning": { "name": "条件" },
423+
"latent": { "name": "latant" },
424+
"mask": { "name": "遮罩" },
425+
"expand": { "name": "扩展" },
426+
"blur": { "name": "模糊" },
427+
"mask_only": { "name": "仅遮罩" }
428+
},
429+
"outputs": {
430+
"0": { "name": "条件" },
431+
"1": { "name": "latant" },
432+
"2": { "name": "遮罩" }
433+
}
418434
}
419435
}

locales/ja/nodeDefs.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,5 +415,21 @@
415415
"2": { "name": "" },
416416
"3": { "name": "高さ" }
417417
}
418+
},
419+
"AILab_ReferenceLatentMask": {
420+
"display_name": "Kontext 参考潜伏マスク (RMBG) 🎭",
421+
"inputs": {
422+
"conditioning": { "name": "条件" },
423+
"latent": { "name": "潜伏" },
424+
"mask": { "name": "マスク" },
425+
"expand": { "name": "拡張" },
426+
"blur": { "name": "ぼかし" },
427+
"mask_only": { "name": "マスクのみ" }
428+
},
429+
"outputs": {
430+
"0": { "name": "条件" },
431+
"1": { "name": "潜伏" },
432+
"2": { "name": "マスク" }
433+
}
418434
}
419435
}

0 commit comments

Comments
 (0)