
Commit 75dc006

Add files via upload
1 parent fe109d4 commit 75dc006

File tree

9 files changed: +747 additions, -161 deletions

AILab_BiRefNet.py

Lines changed: 2 additions & 1 deletion
@@ -445,7 +445,8 @@ def hex_to_rgba(hex_color):
             else:
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
-        rgba = hex_to_rgba(params["background_color"])
+        background_color = params.get("background_color", "#222222")
+        rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, foreground)
         processed_images.append(pil2tensor(composite_image.convert("RGB")))
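
Note: the change above swaps a direct params["background_color"] lookup for a defaulted .get(), so the compositing step no longer raises KeyError when the optional color input is missing. A minimal standalone sketch of the pattern (the params dict here is hypothetical):

    # Hypothetical params dict in which "background_color" was not supplied.
    params = {}
    background_color = params.get("background_color", "#222222")  # falls back to the default gray
    # params["background_color"] would have raised KeyError instead.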

AILab_BodySegment.py

Lines changed: 1 addition & 0 deletions
@@ -213,6 +213,7 @@ def hex_to_rgba(hex_color):
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
         rgba_image = RGB2RGBA(orig_image, mask_image)
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)

AILab_ClothSegment.py

Lines changed: 1 addition & 0 deletions
@@ -249,6 +249,7 @@ def hex_to_rgba(hex_color):
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
         rgba_image = RGB2RGBA(orig_image, mask_image)
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)

AILab_FaceSegment.py

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ def hex_to_rgba(hex_color):
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
         rgba_image = RGB2RGBA(orig_image, mask_image)
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)

AILab_FashionSegment.py

Lines changed: 1 addition & 0 deletions
@@ -333,6 +333,7 @@ def hex_to_rgba(hex_color):
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
         rgba_image = RGB2RGBA(orig_image, mask_image)
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)
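
The Body/Cloth/Face/Fashion segment nodes above all gain the same defaulted color lookup before compositing. For reference, a minimal sketch (not the node code itself) of the PIL compositing they perform: paint a solid background in the chosen color, then alpha-composite the masked subject on top.

    from PIL import Image

    foreground = Image.new('RGBA', (64, 64), (255, 0, 0, 128))    # stand-in for the masked subject
    rgba = (34, 34, 34, 255)                                      # hex_to_rgba("#222222")
    bg_image = Image.new('RGBA', foreground.size, rgba)
    composite_image = Image.alpha_composite(bg_image, foreground)
    composite_image.convert("RGB")                                # RGB output, as in the nodes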

AILab_ImageMaskTools.py

Lines changed: 595 additions & 71 deletions
Large diffs are not rendered by default.

AILab_RMBG.py

Lines changed: 2 additions & 1 deletion
@@ -633,7 +633,8 @@ def hex_to_rgba(hex_color):
             else:
                 raise ValueError("Invalid color format")
             return (r, g, b, a)
-        rgba = hex_to_rgba(params["background_color"])
+        background_color = params.get("background_color", "#222222")
+        rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', orig_image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, foreground)
         processed_images.append(pil2tensor(composite_image.convert("RGB")))

AILab_Segment.py

Lines changed: 8 additions & 0 deletions
@@ -188,6 +188,8 @@ def __init__(self):
         self.clean_state_dict = clean_state_dict
         self.SLConfig = SLConfig
         self.build_model = build_model
+        self._sam_model_cache = {}
+        self._dino_model_cache = {}

     def segment(self, image, prompt, sam_model, dino_model, threshold=0.35,
                 mask_blur=0, mask_offset=0, background="Alpha",
@@ -241,6 +243,8 @@ def segment(self, image, prompt, sam_model, dino_model, threshold=0.35,
         return (pil2tensor(result_image), mask_tensor, mask_image_output)

     def load_sam(self, model_name):
+        if model_name in self._sam_model_cache:
+            return self._sam_model_cache[model_name]
         sam_checkpoint_path = self.get_local_filepath(
             SAM_MODELS[model_name]["model_url"], "sam")
         model_type = SAM_MODELS[model_name]["model_type"]
@@ -252,9 +256,12 @@ def load_sam(self, model_name):
         sam_device = comfy.model_management.get_torch_device()
         sam.to(device=sam_device)
         sam.eval()
+        self._sam_model_cache[model_name] = sam
         return sam

     def load_groundingdino(self, model_name):
+        if model_name in self._dino_model_cache:
+            return self._dino_model_cache[model_name]
         import sys
         from io import StringIO
         temp_stdout = StringIO()
@@ -279,6 +286,7 @@ def load_groundingdino(self, model_name):
             device = comfy.model_management.get_torch_device()
             dino.to(device=device)
             dino.eval()
+            self._dino_model_cache[model_name] = dino
             return dino
         finally:
             output = temp_stdout.getvalue()
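
The additions above memoize SAM and GroundingDINO loading per model name, so repeated runs reuse the already-loaded models instead of rebuilding them. A self-contained sketch of that caching pattern (the _expensive_load helper below is a placeholder, not part of the node):

    class ModelLoader:
        def __init__(self):
            self._sam_model_cache = {}

        def load_sam(self, model_name):
            if model_name in self._sam_model_cache:
                return self._sam_model_cache[model_name]   # cache hit: skip reloading
            model = self._expensive_load(model_name)       # placeholder for checkpoint loading
            self._sam_model_cache[model_name] = model
            return model

        def _expensive_load(self, model_name):
            return object()                                # stand-in for building a SAM model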

AILab_SegmentV2.py

Lines changed: 136 additions & 88 deletions
@@ -97,11 +97,14 @@ def apply_background_color(image: Image.Image, mask_image: Image.Image,
                            background_color: str = "#222222") -> Image.Image:
     rgba_image = image.copy().convert('RGBA')
     rgba_image.putalpha(mask_image.convert('L'))
+
     if background == "Color":
         def hex_to_rgba(hex_color):
             hex_color = hex_color.lstrip('#')
             r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
             return (r, g, b, 255)
+        params = {"background_color": background_color}
+        background_color = params.get("background_color", "#222222")
         rgba = hex_to_rgba(background_color)
         bg_image = Image.new('RGBA', image.size, rgba)
         composite_image = Image.alpha_composite(bg_image, rgba_image)
@@ -146,7 +149,7 @@ def INPUT_TYPES(cls):
                 "dino_model": (list(DINO_MODELS.keys()),),
             },
             "optional": {
-                "threshold": ("FLOAT", {"default": 0.30, "min": 0.05, "max": 0.95, "step": 0.01, "tooltip": tooltips["threshold"]}),
+                "threshold": ("FLOAT", {"default": 0.35, "min": 0.05, "max": 0.95, "step": 0.01, "tooltip": tooltips["threshold"]}),
                 "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
                 "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
                 "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
@@ -167,105 +170,150 @@ def __init__(self):
     def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                    mask_blur=0, mask_offset=0, background="Alpha",
                    background_color="#222222", invert_output=False):
-        img_pil = tensor2pil(image[0]) if image.ndim == 4 else tensor2pil(image)
-        img_np = np.array(img_pil.convert("RGB"))
         device = "cuda" if torch.cuda.is_available() else "cpu"

-        # Load GroundingDINO config and weights
-        dino_info = DINO_MODELS[dino_model]
-        config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
-        weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")
+        # Process batched images
+        batch_size = image.shape[0] if len(image.shape) == 4 else 1
+        if len(image.shape) == 3:
+            image = image.unsqueeze(0)
+
+        result_images = []
+        result_masks = []
+        result_mask_images = []
+
+        for b in range(batch_size):
+            img_pil = tensor2pil(image[b])
+            img_np = np.array(img_pil.convert("RGB"))

-        # Load and cache GroundingDINO model
-        dino_key = (config_path, weights_path, device)
-        if dino_key not in self.dino_model_cache:
-            args = SLConfig.fromfile(config_path)
-            model = build_model(args)
-            checkpoint = torch.load(weights_path, map_location="cpu")
-            model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
-            model.eval()
-            model.to(device)
-            self.dino_model_cache[dino_key] = model
-        dino = self.dino_model_cache[dino_key]
+            # Load GroundingDINO config and weights
+            dino_info = DINO_MODELS[dino_model]
+            config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
+            weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")

-        # Preprocess image for DINO
-        from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
-        transform = Compose([
-            RandomResize([800], max_size=1333),
-            ToTensor(),
-            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ])
-        image_tensor, _ = transform(img_pil.convert("RGB"), None)
-        image_tensor = image_tensor.unsqueeze(0).to(device)
+            # Load and cache GroundingDINO model
+            dino_key = (config_path, weights_path, device)
+            if dino_key not in self.dino_model_cache:
+                args = SLConfig.fromfile(config_path)
+                model = build_model(args)
+                checkpoint = torch.load(weights_path, map_location="cpu")
+                model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+                model.eval()
+                model.to(device)
+                self.dino_model_cache[dino_key] = model
+            dino = self.dino_model_cache[dino_key]

-        # Prepare text prompt
-        text_prompt = prompt if prompt.endswith(".") else prompt + "."
+            # Download/check SAM weights
+            sam_info = SAM_MODELS[sam_model]
+            sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")

-        # Forward pass
-        with torch.no_grad():
-            outputs = dino(image_tensor, captions=[text_prompt])
-            logits = outputs["pred_logits"].sigmoid()[0]
-            boxes = outputs["pred_boxes"][0]
+            # Load SAM model (cache to avoid reloading)
+            sam_key = (sam_info["model_type"], sam_ckpt_path, device)
+            if sam_key not in self.sam_model_cache:
+                try:
+                    sam = sam_model_registry[sam_info["model_type"]](checkpoint=sam_ckpt_path)
+                    sam.to(device)
+                    self.sam_model_cache[sam_key] = SamPredictor(sam)
+                except RuntimeError as e:
+                    if "Unexpected key(s) in state_dict" in str(e):
+                        print("Warning: SAM model loading issue detected, please try using SegmentV1 node instead")
+                        print(f"Error details: {str(e)}")
+                        width, height = img_pil.size
+                        empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
+                        empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+                        result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
+                        result_images.append(pil2tensor(result_image))
+                        result_masks.append(empty_mask)
+                        result_mask_images.append(empty_mask_rgb)
+                        continue
+                    else:
+                        raise e
+            predictor = self.sam_model_cache[sam_key]

-        # Filter boxes by threshold
-        filt_mask = logits.max(dim=1)[0] > threshold
-        boxes_filt = boxes[filt_mask]
-        if boxes_filt.shape[0] == 0:
-            width, height = img_pil.size
-            empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
-            empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
-            result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
-            return (pil2tensor(result_image), empty_mask, empty_mask_rgb)
+            # Preprocess image for DINO
+            from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
+            transform = Compose([
+                RandomResize([800], max_size=1333),
+                ToTensor(),
+                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ])
+            image_tensor, _ = transform(img_pil.convert("RGB"), None)
+            image_tensor = image_tensor.unsqueeze(0).to(device)
+
+            # Prepare text prompt
+            text_prompt = prompt if prompt.endswith(".") else prompt + "."
+
+            # Forward pass
+            with torch.no_grad():
+                outputs = dino(image_tensor, captions=[text_prompt])
+                logits = outputs["pred_logits"].sigmoid()[0]
+                boxes = outputs["pred_boxes"][0]

-        # Convert boxes to xyxy
-        H, W = img_pil.size[1], img_pil.size[0]
-        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
-        boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
-        boxes_xyxy = boxes_xyxy.cpu().numpy()
+            # Filter boxes by threshold
+            filt_mask = logits.max(dim=1)[0] > threshold
+            boxes_filt = boxes[filt_mask]
+
+            # Handle case with no detected boxes
+            if boxes_filt.shape[0] == 0:
+                width, height = img_pil.size
+                empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
+                empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+                result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
+                result_images.append(pil2tensor(result_image))
+                result_masks.append(empty_mask)
+                result_mask_images.append(empty_mask_rgb)
+                continue

-        # Download/check SAM weights
-        sam_info = SAM_MODELS[sam_model]
-        sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")
+            # Convert boxes to xyxy
+            H, W = img_pil.size[1], img_pil.size[0]
+            boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
+            boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
+            boxes_xyxy = boxes_xyxy.cpu().numpy()

-        # Load SAM model (cache to avoid reloading)
-        sam_key = (sam_info["model_type"], sam_ckpt_path, device)
-        if sam_key not in self.sam_model_cache:
-            sam = sam_model_registry[sam_info["model_type"]](checkpoint=sam_ckpt_path)
-            sam.to(device)
-            self.sam_model_cache[sam_key] = SamPredictor(sam)
-        predictor = self.sam_model_cache[sam_key]
+            # Use SAM to get masks for each box
+            predictor.set_image(img_np)
+            boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
+            transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
+            masks, _, _ = predictor.predict_torch(
+                point_coords=None,
+                point_labels=None,
+                boxes=transformed_boxes,
+                multimask_output=False
+            )

-        # Use SAM to get masks for each box
-        predictor.set_image(img_np)
-        boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
-        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
-        masks, _, _ = predictor.predict_torch(
-            point_coords=None,
-            point_labels=None,
-            boxes=transformed_boxes,
-            multimask_output=False
-        )
-        # Process mask following the original implementation
-        print(f"Mask shape before processing: {masks.shape}")
-        # Combine all masks into one
-        combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
-        mask = combined_mask.float().cpu().numpy()
-        print(f"Mask shape after processing: {mask.shape}")
-        # Squeeze out the extra dimension to get a 2D array
-        mask = mask.squeeze(0)
-        print(f"Final mask shape: {mask.shape}")
-        mask = (mask * 255).astype(np.uint8)
-        mask_pil = Image.fromarray(mask, mode="L")
+            # Combine all masks into one
+            combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
+            mask = combined_mask.float().cpu().numpy()
+            mask = mask.squeeze(0)
+            mask = (mask * 255).astype(np.uint8)
+            mask_pil = Image.fromarray(mask, mode="L")

-        mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
-        result_image = apply_background_color(img_pil, mask_image, background, background_color)
-        if background == "Color":
-            result_image = result_image.convert("RGB")
-        else:
-            result_image = result_image.convert("RGBA")
-        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
-        mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
-        return (pil2tensor(result_image), mask_tensor, mask_image_vis)
+            # Process mask and apply background
+            mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
+            result_image = apply_background_color(img_pil, mask_image, background, background_color)
+            if background == "Color":
+                result_image = result_image.convert("RGB")
+            else:
+                result_image = result_image.convert("RGBA")
+
+            # Convert to tensors
+            mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
+            mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
+
+            result_images.append(pil2tensor(result_image))
+            result_masks.append(mask_tensor)
+            result_mask_images.append(mask_image_vis)
+
+        # If no image was processed successfully, return an empty result
+        if len(result_images) == 0:
+            width, height = tensor2pil(image[0]).size
+            empty_mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32, device="cpu")
+            empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
+            return (image, empty_mask, empty_mask_rgb)
+
+        # Concatenate results from all batch items
+        return (torch.cat(result_images, dim=0),
+                torch.cat(result_masks, dim=0),
+                torch.cat(result_mask_images, dim=0))

 NODE_CLASS_MAPPINGS = {
     "SegmentV2": SegmentV2,
