Skip to content

Commit 8b9f8b0

Browse files
authored
Add files via upload
1 parent 75810f1 commit 8b9f8b0

File tree

2 files changed

+113
-149
lines changed

2 files changed

+113
-149
lines changed

AILab_BiRefNet.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ComfyUI-RMBG
1+
# ComfyUI-RMBG v1.9.2
22
# This custom node for ComfyUI provides functionality for background removal using BiRefNet models.
33
#
44
# Model License Notice:
@@ -16,6 +16,7 @@
1616
import sys
1717
import importlib.util
1818
from safetensors.torch import load_file
19+
import cv2
1920

2021
device = "cuda" if torch.cuda.is_available() else "cpu"
2122

@@ -161,26 +162,20 @@ def refine_foreground(image_bchw, masks_b1hw):
161162
refined_fg = []
162163
for i in range(b):
163164
mask = mask_np[i, 0]
164-
# Increase threshold for sharper edges
165-
thresh = 0.45 # Fine-tuned from 0.4
165+
thresh = 0.45
166166
mask_binary = (mask > thresh).astype(np.float32)
167167

168-
# Smaller kernel and sigma for more precise edge control
169168
edge_blur = cv2.GaussianBlur(mask_binary, (3, 3), 0)
169+
transition_mask = np.logical_and(mask > 0.05, mask < 0.95)
170170

171-
# Narrower transition area to reduce white edges
172-
transition_mask = np.logical_and(mask > 0.05, mask < 0.95) # Adjusted from 0.02-0.98
173-
174-
# Increase alpha for stronger original mask influence
175-
alpha = 0.85 # Increased from 0.7
171+
alpha = 0.85
176172
mask_refined = np.where(transition_mask,
177173
alpha * mask + (1-alpha) * edge_blur,
178174
mask_binary)
179175

180-
# Additional edge refinement
181176
edge_region = np.logical_and(mask > 0.2, mask < 0.8)
182177
mask_refined = np.where(edge_region,
183-
mask_refined * 0.98, # Slightly reduce intensity in edge regions
178+
mask_refined * 0.98,
184179
mask_refined)
185180

186181
result = []
@@ -333,7 +328,8 @@ def INPUT_TYPES(s):
333328
"mask_blur": "Specify the amount of blur to apply to the mask edges (0 for no blur, higher values for more blur).",
334329
"mask_offset": "Adjust the mask boundary (positive values expand the mask, negative values shrink it).",
335330
"background": "Choose the background color for the final output (Alpha for transparent background).",
336-
"invert_output": "Enable to invert both the image and mask output (useful for certain effects)."
331+
"invert_output": "Enable to invert both the image and mask output (useful for certain effects).",
332+
"refine_foreground": "Use Fast Foreground Colour Estimation to optimize transparent background"
337333
}
338334

339335
return {
@@ -345,7 +341,8 @@ def INPUT_TYPES(s):
345341
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
346342
"mask_offset": ("INT", {"default": 0, "min": -20, "max": 20, "step": 1, "tooltip": tooltips["mask_offset"]}),
347343
"background": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background"]}),
348-
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]})
344+
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
345+
"refine_foreground": ("BOOLEAN", {"default": False, "tooltip": tooltips["refine_foreground"]})
349346
}
350347
}
351348

@@ -417,11 +414,24 @@ def process_image(self, image, model, **params):
417414
if params["invert_output"]:
418415
mask = Image.fromarray(255 - np.array(mask))
419416

420-
# Create original image from tensor
421-
orig_image = tensor2pil(img)
422-
orig_rgba = orig_image.convert("RGBA")
423-
r, g, b, _ = orig_rgba.split()
424-
foreground = Image.merge('RGBA', (r, g, b, mask))
417+
# Convert to tensors for refine_foreground
418+
img_tensor = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
419+
mask_tensor = torch.from_numpy(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0
420+
421+
if params.get("refine_foreground", False):
422+
refined_fg = refine_foreground(
423+
img_tensor,
424+
mask_tensor
425+
)
426+
refined_fg = tensor2pil(refined_fg[0].permute(1, 2, 0))
427+
orig_image = tensor2pil(img)
428+
r, g, b = refined_fg.split()
429+
foreground = Image.merge('RGBA', (r, g, b, mask))
430+
else:
431+
orig_image = tensor2pil(img)
432+
orig_rgba = orig_image.convert("RGBA")
433+
r, g, b, _ = orig_rgba.split()
434+
foreground = Image.merge('RGBA', (r, g, b, mask))
425435

426436
if params["background"] != "Alpha":
427437
bg_color = bg_colors[params["background"]]

0 commit comments

Comments
 (0)