Skip to content

Commit 320c566

Browse files
authored
Add files via upload
1 parent 5375016 commit 320c566

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

AILab_SAM3Segment.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def INPUT_TYPES(cls):
119119
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
120120
"mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
121121
"invert_output": ("BOOLEAN", {"default": False}),
122+
"unload_model": ("BOOLEAN", {"default": False}),
122123
"background": (["Alpha", "Color"], {"default": "Alpha"}),
123124
"background_color": ("COLORCODE", {"default": "#222222"}),
124125
},
@@ -190,14 +191,18 @@ def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask
190191
mask_rgb = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
191192
return result_image, mask_tensor, mask_rgb
192193

193-
def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, background="Alpha", background_color="#222222"):
194+
def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
195+
194196
if image.ndim == 3:
195197
image = image.unsqueeze(0)
198+
196199
processor, torch_device = self._load_processor(sam3_model, device)
197200
autocast_device = comfy.model_management.get_autocast_device(torch_device)
198201
autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
199202
ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()
203+
200204
result_images, result_masks, result_mask_images = [], [], []
205+
201206
with ctx:
202207
for tensor_img in image:
203208
img_pil, mask_tensor, mask_rgb = self._run_single(
@@ -214,6 +219,15 @@ def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, m
214219
result_images.append(pil2tensor(img_pil))
215220
result_masks.append(mask_tensor)
216221
result_mask_images.append(mask_rgb)
222+
223+
if unload_model:
224+
device_str = "cuda" if torch_device.type == "cuda" else "cpu"
225+
cache_key = (sam3_model, device_str)
226+
if cache_key in self.processor_cache:
227+
del self.processor_cache[cache_key]
228+
if torch_device.type == "cuda":
229+
torch.cuda.empty_cache()
230+
217231
return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)
218232

219233

@@ -223,4 +237,4 @@ def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, m
223237

224238
NODE_DISPLAY_NAME_MAPPINGS = {
225239
"SAM3Segment": "SAM3 Segmentation (RMBG)",
226-
}
240+
}

0 commit comments

Comments
 (0)