Skip to content

Commit d524ade

Browse files
authored
Add files via upload
1 parent a1c4dfb commit d524ade

File tree

6 files changed

+345
-296
lines changed

6 files changed

+345
-296
lines changed

AILab_BiRefNet.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,8 @@ def INPUT_TYPES(s):
350350
}
351351
}
352352

353-
RETURN_TYPES = ("IMAGE", "MASK")
354-
RETURN_NAMES = ("IMAGE", "MASK")
353+
RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
354+
RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
355355
FUNCTION = "process_image"
356356
CATEGORY = "🧪AILab/🧽RMBG"
357357

@@ -447,7 +447,16 @@ def process_image(self, image, model, **params):
447447

448448
processed_masks.append(pil2tensor(mask))
449449

450-
return (torch.cat(processed_images, dim=0), torch.cat(processed_masks, dim=0))
450+
# Create mask image for visualization
451+
mask_images = []
452+
for mask_tensor in processed_masks:
453+
# Convert mask to RGB image format for visualization
454+
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
455+
mask_images.append(mask_image)
456+
457+
mask_image_output = torch.cat(mask_images, dim=0)
458+
459+
return (torch.cat(processed_images, dim=0), torch.cat(processed_masks, dim=0), mask_image_output)
451460

452461
except Exception as e:
453462
handle_model_error(f"Error in image processing: {str(e)}")

AILab_BodySegment.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def INPUT_TYPES(cls):
8585
},
8686
}
8787

88-
RETURN_TYPES = ("IMAGE", "MASK")
89-
RETURN_NAMES = ("IMAGE", "MASK")
88+
RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
89+
RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
9090
FUNCTION = "segment_body"
9191
CATEGORY = "🧪AILab/🧽RMBG"
9292

@@ -217,11 +217,20 @@ def segment_body(self, images, mask_blur=0, mask_offset=0, background_color="Alp
217217
batch_tensor.append(result_image)
218218
batch_masks.append(pil2tensor(mask_image))
219219

220+
# Create mask image for visualization
221+
mask_images = []
222+
for mask_tensor in batch_masks:
223+
# Convert mask to RGB image format for visualization
224+
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
225+
mask_images.append(mask_image)
226+
227+
mask_image_output = torch.cat(mask_images, dim=0)
228+
220229
# Prepare final output
221230
batch_tensor = torch.cat(batch_tensor, dim=0)
222231
batch_masks = torch.cat(batch_masks, dim=0)
223232

224-
return (batch_tensor, batch_masks)
233+
return (batch_tensor, batch_masks, mask_image_output)
225234

226235
except Exception as e:
227236
self.clear_model()

AILab_ClothSegment.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def INPUT_TYPES(cls):
8585
},
8686
}
8787

88-
RETURN_TYPES = ("IMAGE", "MASK")
89-
RETURN_NAMES = ("IMAGE", "MASK")
88+
RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
89+
RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
9090
FUNCTION = "segment_clothes"
9191
CATEGORY = "🧪AILab/🧽RMBG"
9292

@@ -253,11 +253,20 @@ def segment_clothes(self, images, process_res=1024, mask_blur=0, mask_offset=0,
253253
batch_tensor.append(result_image)
254254
batch_masks.append(pil2tensor(mask_image))
255255

256+
# Create mask image for visualization
257+
mask_images = []
258+
for mask_tensor in batch_masks:
259+
# Convert mask to RGB image format for visualization
260+
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
261+
mask_images.append(mask_image)
262+
263+
mask_image_output = torch.cat(mask_images, dim=0)
264+
256265
# Prepare final output
257266
batch_tensor = torch.cat(batch_tensor, dim=0)
258267
batch_masks = torch.cat(batch_masks, dim=0)
259268

260-
return (batch_tensor, batch_masks)
269+
return (batch_tensor, batch_masks, mask_image_output)
261270

262271
except Exception as e:
263272
self.clear_model()

0 commit comments

Comments
 (0)