Skip to content

Commit 9a6dfcf

Browse files
authored
Add files via upload
1 parent ace618f commit 9a6dfcf

File tree

8 files changed

+247
-166
lines changed

8 files changed

+247
-166
lines changed

AILab_BiRefNet.py

Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ def INPUT_TYPES(s):
331331
"model": "Select the BiRefNet model variant to use.",
332332
"mask_blur": "Specify the amount of blur to apply to the mask edges (0 for no blur, higher values for more blur).",
333333
"mask_offset": "Adjust the mask boundary (positive values expand the mask, negative values shrink it).",
334-
"background": "Choose the background color for the final output (Alpha for transparent background).",
335334
"invert_output": "Enable to invert both the image and mask output (useful for certain effects).",
336-
"refine_foreground": "Use Fast Foreground Colour Estimation to optimize transparent background"
335+
"refine_foreground": "Use Fast Foreground Colour Estimation to optimize transparent background",
336+
"background": "Choose background type: Alpha (transparent) or Color (custom background color).",
337+
"background_color": "Choose background color (Alpha = transparent)"
337338
}
338-
339339
return {
340340
"required": {
341341
"image": ("IMAGE", {"tooltip": tooltips["image"]}),
@@ -344,9 +344,10 @@ def INPUT_TYPES(s):
344344
"optional": {
345345
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
346346
"mask_offset": ("INT", {"default": 0, "min": -20, "max": 20, "step": 1, "tooltip": tooltips["mask_offset"]}),
347-
"background": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background"]}),
348347
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
349-
"refine_foreground": ("BOOLEAN", {"default": False, "tooltip": tooltips["refine_foreground"]})
348+
"refine_foreground": ("BOOLEAN", {"default": False, "tooltip": tooltips["refine_foreground"]}),
349+
"background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
350+
"background_color": ("COLOR", {"default": "#222222", "tooltip": tooltips["background_color"]}),
350351
}
351352
}
352353

@@ -358,35 +359,16 @@ def INPUT_TYPES(s):
358359
def process_image(self, image, model, **params):
359360
try:
360361
model_config = MODEL_CONFIG[model]
361-
362-
# Always use model's default resolution
363362
process_res = model_config.get("default_res", 1024)
364-
365-
# Handle special resolution requirements
366363
if model_config.get("force_res", False):
367364
base_res = 512
368365
process_res = ((process_res + base_res - 1) // base_res) * base_res
369366
else:
370367
process_res = process_res // 32 * 32
371-
372368
print(f"Using {model} model with {process_res} resolution")
373-
374369
params["process_res"] = process_res
375-
376370
processed_images = []
377371
processed_masks = []
378-
379-
bg_colors = {
380-
"Alpha": None,
381-
"black": (0, 0, 0),
382-
"white": (255, 255, 255),
383-
"gray": (128, 128, 128),
384-
"green": (0, 255, 0),
385-
"blue": (0, 0, 255),
386-
"red": (255, 0, 0)
387-
}
388-
389-
# Check and download model if needed
390372
cache_status, message = self.model.check_model_cache(model)
391373
if not cache_status:
392374
print(f"Cache check: {message}")
@@ -395,38 +377,24 @@ def process_image(self, image, model, **params):
395377
if not download_status:
396378
handle_model_error(download_message)
397379
print("Model files downloaded successfully")
398-
399-
# Load model if needed
400380
self.model.load_model(model)
401-
402381
for img in image:
403-
# Get mask from model
404382
mask = self.model.process_image(img, params)
405-
406-
# Post-process mask
407383
if params["mask_blur"] > 0:
408384
mask = mask.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
409-
410385
if params["mask_offset"] != 0:
411386
if params["mask_offset"] > 0:
412387
for _ in range(params["mask_offset"]):
413388
mask = mask.filter(ImageFilter.MaxFilter(3))
414389
else:
415390
for _ in range(-params["mask_offset"]):
416391
mask = mask.filter(ImageFilter.MinFilter(3))
417-
418392
if params["invert_output"]:
419393
mask = Image.fromarray(255 - np.array(mask))
420-
421-
# Convert to tensors for refine_foreground
422394
img_tensor = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
423395
mask_tensor = torch.from_numpy(np.array(mask)).unsqueeze(0).unsqueeze(0) / 255.0
424-
425396
if params.get("refine_foreground", False):
426-
refined_fg = refine_foreground(
427-
img_tensor,
428-
mask_tensor
429-
)
397+
refined_fg = refine_foreground(img_tensor, mask_tensor)
430398
refined_fg = tensor2pil(refined_fg[0].permute(1, 2, 0))
431399
orig_image = tensor2pil(img)
432400
r, g, b = refined_fg.split()
@@ -436,28 +404,30 @@ def process_image(self, image, model, **params):
436404
orig_rgba = orig_image.convert("RGBA")
437405
r, g, b, _ = orig_rgba.split()
438406
foreground = Image.merge('RGBA', (r, g, b, mask))
439-
440-
if params["background"] != "Alpha":
441-
bg_color = bg_colors[params["background"]]
442-
bg_image = Image.new('RGBA', orig_image.size, (*bg_color, 255))
407+
if params["background"] == "Alpha":
408+
processed_images.append(pil2tensor(foreground))
409+
else:
410+
def hex_to_rgba(hex_color):
411+
hex_color = hex_color.lstrip('#')
412+
if len(hex_color) == 6:
413+
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
414+
a = 255
415+
elif len(hex_color) == 8:
416+
r, g, b, a = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16), int(hex_color[6:8], 16)
417+
else:
418+
raise ValueError("Invalid color format")
419+
return (r, g, b, a)
420+
rgba = hex_to_rgba(params["background_color"])
421+
bg_image = Image.new('RGBA', orig_image.size, rgba)
443422
composite_image = Image.alpha_composite(bg_image, foreground)
444423
processed_images.append(pil2tensor(composite_image.convert("RGB")))
445-
else:
446-
processed_images.append(pil2tensor(foreground))
447-
448424
processed_masks.append(pil2tensor(mask))
449-
450-
# Create mask image for visualization
451425
mask_images = []
452426
for mask_tensor in processed_masks:
453-
# Convert mask to RGB image format for visualization
454427
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
455428
mask_images.append(mask_image)
456-
457429
mask_image_output = torch.cat(mask_images, dim=0)
458-
459430
return (torch.cat(processed_images, dim=0), torch.cat(processed_masks, dim=0), mask_image_output)
460-
461431
except Exception as e:
462432
handle_model_error(f"Error in image processing: {str(e)}")
463433

AILab_BodySegment.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def INPUT_TYPES(cls):
6767
"process_res": "Processing resolution (fixed at 512x512)",
6868
"mask_blur": "Blur amount for mask edges",
6969
"mask_offset": "Expand/Shrink mask boundary",
70-
"background_color": "Choose background color (Alpha = transparent)",
7170
"invert_output": "Invert both image and mask output",
71+
"background": "Choose background type: Alpha (transparent) or Color (custom background color).",
72+
"background_color": "Choose background color (Alpha = transparent)"
7273
}
7374

7475
return {
@@ -80,8 +81,9 @@ def INPUT_TYPES(cls):
8081
for cls_name in available_classes},
8182
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
8283
"mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
83-
"background_color": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background_color"]}),
8484
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
85+
"background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
86+
"background_color": ("COLOR", {"default": "#222222", "tooltip": tooltips["background_color"]}),
8587
},
8688
}
8789

@@ -121,7 +123,7 @@ def download_model_files(self):
121123
except Exception as e:
122124
return False, f"Error downloading model file: {str(e)}"
123125

124-
def segment_body(self, images, mask_blur=0, mask_offset=0, background_color="Alpha", invert_output=False, **class_selections):
126+
def segment_body(self, images, mask_blur=0, mask_offset=0, background="Alpha", background_color="#222222", invert_output=False, **class_selections):
125127
try:
126128
# Check and download model if needed
127129
cache_status, message = self.check_model_cache()
@@ -196,21 +198,23 @@ def segment_body(self, images, mask_blur=0, mask_offset=0, background_color="Alp
196198
mask_image = Image.fromarray(255 - np.array(mask_image))
197199

198200
# Handle background color
199-
if background_color == "Alpha":
201+
if background == "Alpha":
200202
rgba_image = RGB2RGBA(orig_image, mask_image)
201203
result_image = pil2tensor(rgba_image)
202204
else:
203-
bg_colors = {
204-
"black": (0, 0, 0),
205-
"white": (255, 255, 255),
206-
"gray": (128, 128, 128),
207-
"green": (0, 255, 0),
208-
"blue": (0, 0, 255),
209-
"red": (255, 0, 0)
210-
}
211-
205+
def hex_to_rgba(hex_color):
206+
hex_color = hex_color.lstrip('#')
207+
if len(hex_color) == 6:
208+
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
209+
a = 255
210+
elif len(hex_color) == 8:
211+
r, g, b, a = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16), int(hex_color[6:8], 16)
212+
else:
213+
raise ValueError("Invalid color format")
214+
return (r, g, b, a)
212215
rgba_image = RGB2RGBA(orig_image, mask_image)
213-
bg_image = Image.new('RGBA', orig_image.size, (*bg_colors[background_color], 255))
216+
rgba = hex_to_rgba(background_color)
217+
bg_image = Image.new('RGBA', orig_image.size, rgba)
214218
composite_image = Image.alpha_composite(bg_image, rgba_image)
215219
result_image = pil2tensor(composite_image.convert('RGB'))
216220

AILab_ClothSegment.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ def INPUT_TYPES(cls):
6666
"process_res": "Processing resolution (higher = more VRAM)",
6767
"mask_blur": "Blur amount for mask edges",
6868
"mask_offset": "Expand/Shrink mask boundary",
69-
"background_color": "Choose background color (Alpha = transparent)",
7069
"invert_output": "Invert both image and mask output",
70+
"background": "Choose background type: Alpha (transparent) or Color (custom background color).",
71+
"background_color": "Choose background color (Alpha = transparent)"
7172
}
7273

7374
return {
@@ -80,8 +81,9 @@ def INPUT_TYPES(cls):
8081
"process_res": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 32, "tooltip": tooltips["process_res"]}),
8182
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
8283
"mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
83-
"background_color": (["Alpha", "black", "white", "gray", "green", "blue", "red"], {"default": "Alpha", "tooltip": tooltips["background_color"]}),
8484
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
85+
"background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
86+
"background_color": ("COLOR", {"default": "#222222", "tooltip": tooltips["background_color"]}),
8587
},
8688
}
8789

@@ -141,7 +143,7 @@ def download_model_files(self):
141143
except Exception as e:
142144
return False, f"Error downloading model files: {str(e)}"
143145

144-
def segment_clothes(self, images, process_res=1024, mask_blur=0, mask_offset=0, background_color="Alpha", invert_output=False, **class_selections):
146+
def segment_clothes(self, images, process_res=1024, mask_blur=0, mask_offset=0, background="Alpha", background_color="#222222", invert_output=False, **class_selections):
145147
try:
146148
# Check and download model if needed
147149
cache_status, message = self.check_model_cache()
@@ -232,21 +234,23 @@ def segment_clothes(self, images, process_res=1024, mask_blur=0, mask_offset=0,
232234
mask_image = Image.fromarray(255 - np.array(mask_image))
233235

234236
# Handle background color
235-
if background_color == "Alpha":
237+
if background == "Alpha":
236238
rgba_image = RGB2RGBA(orig_image, mask_image)
237239
result_image = pil2tensor(rgba_image)
238240
else:
239-
bg_colors = {
240-
"black": (0, 0, 0),
241-
"white": (255, 255, 255),
242-
"gray": (128, 128, 128),
243-
"green": (0, 255, 0),
244-
"blue": (0, 0, 255),
245-
"red": (255, 0, 0)
246-
}
247-
241+
def hex_to_rgba(hex_color):
242+
hex_color = hex_color.lstrip('#')
243+
if len(hex_color) == 6:
244+
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
245+
a = 255
246+
elif len(hex_color) == 8:
247+
r, g, b, a = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16), int(hex_color[6:8], 16)
248+
else:
249+
raise ValueError("Invalid color format")
250+
return (r, g, b, a)
248251
rgba_image = RGB2RGBA(orig_image, mask_image)
249-
bg_image = Image.new('RGBA', orig_image.size, (*bg_colors[background_color], 255))
252+
rgba = hex_to_rgba(background_color)
253+
bg_image = Image.new('RGBA', orig_image.size, rgba)
250254
composite_image = Image.alpha_composite(bg_image, rgba_image)
251255
result_image = pil2tensor(composite_image.convert('RGB'))
252256

0 commit comments

Comments
 (0)