Skip to content

Commit 8ad43b4

Browse files
authored
Add files via upload
1 parent ff81bb2 commit 8ad43b4

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed

AILab_SAM3Segment.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import os
2+
import sys
3+
from contextlib import nullcontext
4+
5+
import numpy as np
6+
import torch
7+
from PIL import Image, ImageFilter
8+
from torch.hub import download_url_to_file
9+
10+
import folder_paths
11+
import comfy.model_management
12+
13+
from AILab_ImageMaskTools import pil2tensor, tensor2pil
14+
15+
# Directory holding this file, and the "sam3" directory expected beside it.
CURRENT_DIR = os.path.dirname(__file__)
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")

# Put that directory at the front of the import path (only once) so the
# `sam3.*` imports further down are looked up there first.
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)

# BPE vocabulary file that is later passed to the model builder as `bpe_path`.
# Checked at import time so a missing asset fails early with a clear message.
SAM3_BPE_PATH = os.path.join(
    SAM3_LOCAL_DIR,
    "assets",
    "bpe_simple_vocab_16e6.txt.gz",
)
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
23+
24+
from sam3.model_builder import build_sam3_image_model # noqa: E402
25+
from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402
26+
27+
# Download URL and on-disk file name of the default SAM3 checkpoint.
_DEFAULT_PT_ENTRY = dict(
    model_url="https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
    filename="sam3.pt",
)

# Selectable models keyed by name. The entry is an independent copy, so
# changing it never alters the default definition above.
SAM3_MODELS = {"sam3": dict(_DEFAULT_PT_ENTRY)}
35+
36+
37+
def get_sam3_pt_models():
    """Return a one-entry dictionary ``{"sam3": entry}`` for the PT checkpoint.

    Resolution order:

    1. The ``"sam3"`` entry of ``SAM3_MODELS`` when its filename ends in
       ``.pt`` (returned as-is).
    2. Otherwise ``SAM3_MODELS`` is walked in order and the first usable entry
       wins: an entry whose filename ends in ``.pt`` is returned as-is; an
       entry whose key contains ``"sam3"`` is returned as a copy with its
       ``model_url`` and ``filename`` replaced by the default PT values.
    3. A copy of ``_DEFAULT_PT_ENTRY`` when nothing matched.

    Empty or ``None`` entries, and entries without a usable filename, are
    skipped instead of raising.
    """

    def _has_pt_filename(entry):
        # Guard first: entries may be None/empty and "filename" may be
        # missing or None, none of which should abort the lookup.
        if not entry:
            return False
        return str(entry.get("filename") or "").endswith(".pt")

    primary = SAM3_MODELS.get("sam3")
    if _has_pt_filename(primary):
        return {"sam3": primary}

    # Fallback: upgrade any legacy entry to PT naming.
    for key, value in SAM3_MODELS.items():
        if not value:
            continue
        if _has_pt_filename(value):
            return {"sam3": value}
        if "sam3" in key:
            candidate = value.copy()
            candidate["model_url"] = _DEFAULT_PT_ENTRY["model_url"]
            candidate["filename"] = _DEFAULT_PT_ENTRY["filename"]
            return {"sam3": candidate}

    return {"sam3": _DEFAULT_PT_ENTRY.copy()}
52+
53+
54+
def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a mask: optional inversion, then blur, then grow/shrink.

    Args:
        mask_image: PIL image holding the mask.
        invert_output: When true, every pixel value ``v`` becomes ``255 - v``.
        mask_blur: Gaussian blur radius; values ``<= 0`` skip the blur.
        mask_offset: Number of pixels to grow (positive) or shrink (negative)
            the bright area of the mask; ``0`` leaves it untouched.

    Returns:
        The processed PIL image. When no step applies, the input object
        itself is returned unchanged.
    """
    if invert_output:
        mask_image = Image.fromarray(255 - np.array(mask_image))

    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    if mask_offset != 0:
        # MaxFilter dilates (grows bright regions), MinFilter erodes them.
        morph = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        # Each 3x3 pass moves the edge by exactly one pixel, so |mask_offset|
        # passes move it by |mask_offset| pixels. Repeating a
        # (2 * offset + 1)-wide window that many times would instead move it
        # by roughly offset ** 2 pixels and cost far more per pass.
        for _ in range(abs(mask_offset)):
            mask_image = mask_image.filter(morph(3))

    return mask_image
66+
67+
68+
def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Cut *image* out with *mask_image*, optionally flattening onto a color.

    Args:
        image: Source PIL image.
        mask_image: PIL mask; it is converted to mode "L" and installed as
            the alpha channel of the result.
        background: ``"Color"`` composites the cut-out over
            ``background_color`` and returns an "RGB" image. Any other value
            returns the "RGBA" cut-out.
        background_color: Hex color, read only when ``background == "Color"``.
            Accepts ``"#RRGGBB"`` (the leading ``#`` is optional and anything
            after the sixth digit is ignored) and the ``"#RGB"`` shorthand.

    Returns:
        A new PIL image; the input image is not modified.

    Raises:
        ValueError: If ``background == "Color"`` and ``background_color``
            cannot be parsed as a hex color.
    """
    fill = None
    if background == "Color":
        # Parse up front so a malformed color is reported before any image
        # work is done, and with a message that names the offending value.
        digits = background_color.strip().lstrip("#")
        if len(digits) == 3:
            # Expand CSS-style shorthand: "abc" -> "aabbcc".
            digits = "".join(ch * 2 for ch in digits)
        try:
            if len(digits) < 6:
                raise ValueError(digits)
            r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            raise ValueError(
                f"Invalid background_color {background_color!r}; "
                "expected a hex string such as '#222222'."
            ) from None
        fill = (r, g, b, 255)

    # convert() already returns a new image, so no separate copy is needed.
    cutout = image.convert("RGBA")
    cutout.putalpha(mask_image.convert("L"))
    if fill is None:
        return cutout

    backdrop = Image.new("RGBA", image.size, fill)
    return Image.alpha_composite(backdrop, cutout).convert("RGB")
78+
79+
80+
def get_or_download_model_file(filename, url):
    """Return a local path for *filename*, downloading it from *url* if needed.

    The file is first looked up through ``folder_paths.get_full_path`` under
    the ``"sam3"`` folder name, when that function is available. If that does
    not yield an existing file, ``<models_dir>/sam3/<filename>`` is used,
    where ``models_dir`` is taken from ``folder_paths`` or defaults to a
    ``models`` directory beside this file. That directory is created if
    missing, and the download only happens when nothing exists at the target
    path yet.
    """
    if hasattr(folder_paths, "get_full_path"):
        registered = folder_paths.get_full_path("sam3", filename)
        if registered and os.path.isfile(registered):
            return registered

    fallback_root = os.path.join(CURRENT_DIR, "models")
    target_dir = os.path.join(
        getattr(folder_paths, "models_dir", fallback_root),
        "sam3",
    )
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, filename)
    if os.path.exists(target):
        return target

    print(f"Downloading {filename} from {url} ...")
    download_url_to_file(url, target)
    return target
94+
95+
96+
def _resolve_device(user_choice):
    """Translate the node's device option into a ``torch.device``.

    ``"CPU"`` always yields the CPU device. ``"GPU"`` yields ``cuda`` and
    raises ``RuntimeError`` when the device reported by ComfyUI's model
    management is not a CUDA device. Any other value yields whatever device
    ComfyUI's model management reports.
    """
    comfy_device = comfy.model_management.get_torch_device()

    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice != "GPU":
        return comfy_device

    # Explicit GPU request: only honoured when ComfyUI itself is on CUDA.
    if comfy_device.type == "cuda":
        return torch.device("cuda")
    raise RuntimeError("GPU unavailable")
105+
106+
107+
class SAM3Segment:
    """ComfyUI node that segments images with SAM3 from a text prompt.

    For every image of the input batch the node returns the cut-out image,
    the union of all detected masks, and a 3-channel copy of that mask.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: names, types, defaults and ranges."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # Built processors keyed by (model name, "cuda" | "cpu"), so each
        # model is constructed at most once per device for this instance.
        self.processor_cache = {}

    def _load_processor(self, model_choice, device_choice):
        """Return ``(processor, torch_device)`` for the requested model/device.

        The processor is built on first use (fetching the checkpoint if
        needed) and reused from ``self.processor_cache`` afterwards.
        """
        torch_device = _resolve_device(device_choice)
        # The model is only ever built for "cuda" or "cpu"; every non-CUDA
        # device maps to "cpu" here.
        device_str = "cuda" if torch_device.type == "cuda" else "cpu"
        cache_key = (model_choice, device_str)
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            self.processor_cache[cache_key] = Sam3Processor(model, device=device_str)
        return self.processor_cache[cache_key], torch_device

    @staticmethod
    def _compose_outputs(img_pil, mask_image, background, background_color):
        """Build the ``(image, mask tensor, 3-channel mask tensor)`` triple.

        Returns the cut-out PIL image ("RGBA" when ``background`` is
        ``"Alpha"``, otherwise "RGB"), the mask as a float32 tensor of shape
        ``(1, H, W)`` scaled to ``[0, 1]``, and a ``(1, H, W, 3)`` view of it.
        """
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        result_image = result_image.convert("RGBA" if background == "Alpha" else "RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = (
            mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width))
            .movedim(1, -1)
            .expand(-1, -1, -1, 3)
        )
        return result_image, mask_tensor, mask_rgb

    def _empty_result(self, img_pil, background, background_color, invert=False):
        """Return the outputs used when the model detected nothing.

        The mask is uniformly 0, or uniformly 255 when ``invert`` is true.
        That is what inverting an all-zero mask yields, so the "nothing
        found" case honours ``invert_output`` like the normal path does.
        Blur and offset are not applied because they cannot change a
        constant mask.
        """
        mask_image = Image.new("L", img_pil.size, 255 if invert else 0)
        return self._compose_outputs(img_pil, mask_image, background, background_color)

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one image and return ``(PIL image, mask, 3-channel mask)``.

        An empty or whitespace-only prompt falls back to ``"object"``. All
        masks returned by the processor are merged with a per-pixel maximum
        before the blur/offset/invert post-processing is applied.
        """
        img_pil = tensor2pil(img_tensor)
        text = prompt.strip() or "object"

        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)

        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            return self._empty_result(img_pil, background, background_color, invert)

        masks = masks.detach().float().to("cpu")
        if masks.ndim == 4:
            # presumably (N, 1, H, W); drop the singleton axis — TODO confirm
            masks = masks.squeeze(1)
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        # A 2-D uint8 array is inferred as mode "L"; the explicit `mode`
        # argument is avoided because Pillow 11.3 deprecated it.
        mask_image = Image.fromarray(mask_np)
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        return self._compose_outputs(img_pil, mask_image, background, background_color)

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, background="Alpha", background_color="#222222"):
        """Node entry point: segment every image of the batch.

        Args:
            image: Image tensor; a 3-D tensor is treated as a single image
                and given a leading batch axis.
            prompt: Text describing what to segment.
            sam3_model: Key into ``SAM3_MODELS``.
            device: ``"Auto"``, ``"CPU"`` or ``"GPU"``.
            confidence_threshold: Detection threshold handed to the processor.
            mask_blur, mask_offset, invert_output: Mask post-processing
                options, forwarded to ``process_mask``.
            background, background_color: Compositing options, forwarded to
                ``apply_background_color``.

        Returns:
            ``(images, masks, mask_images)``, each concatenated along the
            batch axis.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)

        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        # bfloat16 autocast on CUDA only; everything else runs unwrapped.
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()

        result_images, result_masks, result_mask_images = [], [], []
        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor,
                    tensor_img,
                    prompt,
                    confidence_threshold,
                    mask_blur,
                    mask_offset,
                    invert_output,
                    background,
                    background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)

        return (
            torch.cat(result_images, dim=0),
            torch.cat(result_masks, dim=0),
            torch.cat(result_mask_images, dim=0),
        )
218+
219+
220+
# Node identifier -> implementing class.
# NOTE(review): presumably read by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = {"SAM3Segment": SAM3Segment}

# Node identifier -> human-readable name for the same node.
NODE_DISPLAY_NAME_MAPPINGS = {"SAM3Segment": "SAM3 Segmentation (RMBG)"}

0 commit comments

Comments
 (0)