Commit f4ecb32

Add files via upload
1 parent c36a6a8 commit f4ecb32

File tree

    AILab_FashionSegment.py
    AILab_SegmentV2.py

2 files changed: +5 -43 lines changed

AILab_FashionSegment.py

Lines changed: 4 additions & 2 deletions
@@ -236,8 +236,10 @@ def download_model_files(self):
         except Exception as e:
             return False, f"Error downloading model files: {str(e)}"
 
-    def segment_fashion(self, images, accessories_options, process_res=512, mask_blur=0, mask_offset=0,
+    def segment_fashion(self, images, accessories_options=None, process_res=512, mask_blur=0, mask_offset=0,
                         background="Alpha", background_color="#222222", invert_output=False, **class_selections):
+        if accessories_options is None:
+            accessories_options = []
         try:
             # Check and download model
             cache_status, message = self.check_model_cache()
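
Note: the new None-sentinel default makes accessories_options optional and sidesteps Python's shared-mutable-default pitfall, where a bare [] default would be reused across calls. A minimal sketch of the pattern, with an illustrative function that is not part of this repo:

def process(items=None):
    # A default of items=[] would be created once at definition time and
    # shared by every call that omits the argument; None avoids that.
    if items is None:
        items = []
    items.append("processed")
    return items

print(process())  # ['processed']
print(process())  # ['processed'] -- a fresh list each call, no leaked state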
@@ -360,7 +362,7 @@ def hex_to_rgba(hex_color):
             self.clear_model()
             raise RuntimeError(f"Error in fashion segmentation: {str(e)}")
         finally:
-            if not self.model.training:
+            if self.model is not None and not self.model.training:
                 self.clear_model()
 
     def __del__(self):
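
Note: the added None check matters because the finally block runs even when an exception is raised before a model has been loaded. A hypothetical sketch of the failure mode the guard prevents (class and names are illustrative):

class Segmenter:
    def __init__(self):
        self.model = None  # nothing loaded yet

    def clear_model(self):
        self.model = None

    def run(self):
        try:
            raise RuntimeError("model failed to load")  # simulate an early failure
        finally:
            # Without the None check, self.model.training would raise an
            # AttributeError here and mask the original RuntimeError.
            if self.model is not None and not self.model.training:
                self.clear_model()

try:
    Segmenter().run()
except RuntimeError as e:
    print(e)  # "model failed to load" -- the original error survives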

AILab_SegmentV2.py

Lines changed: 1 addition & 41 deletions
@@ -16,7 +16,6 @@
 
 from AILab_ImageMaskTools import pil2tensor, tensor2pil
 
-# SAM model definitions (6 models)
 SAM_MODELS = {
     "sam_vit_h (2.56GB)": {
         "model_url": "https://huggingface.co/1038lab/sam/resolve/main/sam_vit_h.pth",
@@ -50,7 +49,6 @@
     }
 }
 
-# GroundingDINO model definitions (2 models)
 DINO_MODELS = {
     "GroundingDINO_SwinT_OGC (694MB)": {
         "config_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
@@ -97,7 +95,6 @@ def apply_background_color(image: Image.Image, mask_image: Image.Image,
                            background_color: str = "#222222") -> Image.Image:
     rgba_image = image.copy().convert('RGBA')
     rgba_image.putalpha(mask_image.convert('L'))
-
     if background == "Color":
         def hex_to_rgba(hex_color):
             hex_color = hex_color.lstrip('#')
@@ -172,25 +169,18 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                    background_color="#222222", invert_output=False):
        device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        # Process batched images
         batch_size = image.shape[0] if len(image.shape) == 4 else 1
         if len(image.shape) == 3:
             image = image.unsqueeze(0)
-
         result_images = []
         result_masks = []
         result_mask_images = []
-
         for b in range(batch_size):
             img_pil = tensor2pil(image[b])
             img_np = np.array(img_pil.convert("RGB"))
-
-            # Load GroundingDINO config and weights
             dino_info = DINO_MODELS[dino_model]
             config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
             weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")
-
-            # Load and cache GroundingDINO model
             dino_key = (config_path, weights_path, device)
             if dino_key not in self.dino_model_cache:
                 args = SLConfig.fromfile(config_path)
@@ -201,12 +191,8 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 model.to(device)
                 self.dino_model_cache[dino_key] = model
             dino = self.dino_model_cache[dino_key]
-
-            # Download/check SAM weights
             sam_info = SAM_MODELS[sam_model]
             sam_ckpt_path = get_or_download_model_file(sam_info["filename"], sam_info["model_url"], "SAM")
-
-            # Load SAM model (cache to avoid reloading)
             sam_key = (sam_info["model_type"], sam_ckpt_path, device)
             if sam_key not in self.sam_model_cache:
                 try:
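
Both the DINO and SAM branches above follow the same keyed-cache pattern: build a tuple key from config, weights, and device, and construct the model only on a cache miss. A self-contained sketch of the idea; ModelCache and load_model are illustrative stand-ins, not the repo's API:

class ModelCache:
    def __init__(self):
        self._cache = {}

    def get(self, config_path, weights_path, device, load_model):
        key = (config_path, weights_path, device)  # same key shape as dino_key
        if key not in self._cache:
            self._cache[key] = load_model(config_path, weights_path, device)
        return self._cache[key]

cache = ModelCache()
loader = lambda c, w, d: object()  # dummy loader standing in for real model loading
m1 = cache.get("cfg.py", "weights.pth", "cpu", loader)
m2 = cache.get("cfg.py", "weights.pth", "cpu", loader)
print(m1 is m2)  # True -- the second call reuses the cached model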
@@ -230,8 +216,6 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 else:
                     raise e
             predictor = self.sam_model_cache[sam_key]
-
-            # Preprocess image for DINO
             from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
             transform = Compose([
                 RandomResize([800], max_size=1333),
@@ -240,21 +224,13 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
             ])
             image_tensor, _ = transform(img_pil.convert("RGB"), None)
             image_tensor = image_tensor.unsqueeze(0).to(device)
-
-            # Prepare text prompt
             text_prompt = prompt if prompt.endswith(".") else prompt + "."
-
-            # Forward pass
             with torch.no_grad():
                 outputs = dino(image_tensor, captions=[text_prompt])
             logits = outputs["pred_logits"].sigmoid()[0]
             boxes = outputs["pred_boxes"][0]
-
-            # Filter boxes by threshold
             filt_mask = logits.max(dim=1)[0] > threshold
             boxes_filt = boxes[filt_mask]
-
-            # Handle case with no detected boxes
             if boxes_filt.shape[0] == 0:
                 width, height = img_pil.size
                 empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
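
For context, the filtering step in this hunk scores each GroundingDINO query by its best token logit and keeps only boxes above the threshold. A small sketch with dummy tensors standing in for the real model outputs (the 900-query, 256-token shapes are assumptions):

import torch

logits = torch.randn(900, 256).sigmoid()  # stand-in for outputs["pred_logits"].sigmoid()[0]
boxes = torch.rand(900, 4)                # stand-in for outputs["pred_boxes"][0]

threshold = 0.30
filt_mask = logits.max(dim=1)[0] > threshold  # best token score per query
boxes_filt = boxes[filt_mask]
print(f"kept {boxes_filt.shape[0]} of {boxes.shape[0]} boxes")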
@@ -264,14 +240,10 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 result_masks.append(empty_mask)
                 result_mask_images.append(empty_mask_rgb)
                 continue
-
-            # Convert boxes to xyxy
             H, W = img_pil.size[1], img_pil.size[0]
             boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
             boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
             boxes_xyxy = boxes_xyxy.cpu().numpy()
-
-            # Use SAM to get masks for each box
             predictor.set_image(img_np)
             boxes_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32, device=predictor.device)
             transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, img_np.shape[:2])
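
The conversion in this hunk turns DINO's normalized center-format boxes into the absolute-pixel corner format SAM expects. Written out by hand for clarity; cxcywh_to_xyxy below is my own helper, equivalent in intent to box_ops.box_cxcywh_to_xyxy:

import torch

def cxcywh_to_xyxy(b: torch.Tensor) -> torch.Tensor:
    # (center x, center y, width, height) -> (x1, y1, x2, y2)
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)

W, H = 640, 480                                # image width and height
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])   # one centered, normalized box
xyxy = cxcywh_to_xyxy(boxes) * torch.tensor([W, H, W, H], dtype=torch.float32)
print(xyxy)  # tensor([[256., 144., 384., 336.]])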
@@ -281,38 +253,27 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
                 boxes=transformed_boxes,
                 multimask_output=False
             )
-
-            # Combine all masks into one
-            combined_mask = torch.max(masks, dim=0)[0]  # Take maximum across all masks
+            combined_mask = torch.max(masks, dim=0)[0]
             mask = combined_mask.float().cpu().numpy()
             mask = mask.squeeze(0)
             mask = (mask * 255).astype(np.uint8)
             mask_pil = Image.fromarray(mask, mode="L")
-
-            # Process mask and apply background
             mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
             result_image = apply_background_color(img_pil, mask_image, background, background_color)
             if background == "Color":
                 result_image = result_image.convert("RGB")
             else:
                 result_image = result_image.convert("RGBA")
-
-            # Convert to tensors
             mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
             mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
-
             result_images.append(pil2tensor(result_image))
             result_masks.append(mask_tensor)
             result_mask_images.append(mask_image_vis)
-
-        # If no image was processed successfully, return an empty result
         if len(result_images) == 0:
             width, height = tensor2pil(image[0]).size
             empty_mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32, device="cpu")
             empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
             return (image, empty_mask, empty_mask_rgb)
-
-        # Merge the results from all batches
         return (torch.cat(result_images, dim=0),
                 torch.cat(result_masks, dim=0),
                 torch.cat(result_mask_images, dim=0))
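
The retained combined_mask line unions SAM's per-box masks with an element-wise maximum, and the closing torch.cat calls stack the per-image results back into a batch. A dummy-data sketch of both steps (shapes are assumptions):

import torch

# Stand-in for SAM's per-box output: (num_boxes, 1, H, W), here 3 boxes on a 4x6 image.
masks = (torch.rand(3, 1, 4, 6) > 0.5).float()

# Union across boxes: a pixel is foreground if any box's mask covers it.
combined_mask = torch.max(masks, dim=0)[0]  # (1, H, W)

# Per-image results are then concatenated along the batch dimension.
result_masks = [combined_mask.unsqueeze(0), combined_mask.unsqueeze(0)]
batched = torch.cat(result_masks, dim=0)    # (2, 1, H, W)
print(combined_mask.shape, batched.shape)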
@@ -324,4 +285,3 @@ def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
 NODE_DISPLAY_NAME_MAPPINGS = {
     "SegmentV2": "Segmentation V2 (RMBG)",
 }
-
